utils.py 15.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5

import random
from dataclasses import dataclass
6
from typing import TypeAlias
7

8
import numpy as np
9
10
11
12
13
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.engine.arg_utils import EngineArgs
from vllm.v1.engine import EngineCoreOutput, FinishReason
14
from vllm.v1.metrics.stats import PrefillStats
15
16
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

17
# Tokenizer annotation used by the test vectors below: either the slow
# (`PreTrainedTokenizer`) or the fast (`PreTrainedTokenizerFast`) HF class.
GeneralTokenizerType: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast

# Top-logprob count requested per *sampled* token in logprob tests.
NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5
# Top-logprob count requested per *prompt* token in prompt-logprob tests.
NUM_PROMPT_LOGPROBS_UNDER_TEST = 7

# Hugging Face model id whose tokenizer the tests load.
TOKENIZER_NAME = "meta-llama/Llama-3.2-1B"

# Reference texts used by the output processor tests.
FULL_STRINGS = [
    "My name is Robert from Neural Magic and I love working on vLLM so much!",
    "Red Hat is the best open source company by far across Linux, K8s, and AI.",
    "Nick is the name of my brother in addition to my colleague from Red Hat.",
]
# One stop string per entry of FULL_STRINGS; each occurs inside its text.
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
# NOTE(review): presumably the number of leading tokens treated as the
# prompt — confirm against the tests that import it.
PROMPT_LEN = 5

# Seed Python's `random` at import time so the random ranks drawn with
# `random.randint` by the helpers in this module are reproducible.
random.seed(42)


def _create_random_top_logprob_test_vector(
    num_logprobs: int,
    lower: float,
    upper: float,
) -> torch.Tensor:
    """Create a random vector of top logprob float values.
43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
    Use to create fake sample logprobs for testing.

    Note that a real production scenario would require
    logprobs to be sorted in descending order, something
    which is omitted in this function.

    Args:
      num_logprobs: number of top logprobs
      lower: lower range of logprob float values
      upper: upper range of logprob float values

    Returns:
      1D length-`num_logprobs` torch Tensor of float logprob values
    """
    return torch.rand(num_logprobs) * (upper - lower) + lower


def _create_random_top_logprob_test_matrix(
62
    shape: tuple,
63
64
65
66
    lower: float,
    upper: float,
) -> torch.Tensor:
    """Create a random matrix of top logprob float values.
67

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
    Use to create fake prompt logprobs for testing.

    Note that a real production scenario would require
    logprobs to be sorted in descending order along rows,
    something which is omitted in this function.

    Args:
      shape: (num_tokens,num_logprobs) tuple representing
             matrix shape
      lower: lower range of logprob float values
      upper: upper range of logprob float values

    Returns:
      2D num_tokens x num_logprobs torch Tensor of float logprob values
    """
    return torch.rand(*shape) * (upper - lower) + lower


def _create_random_top_token_test_vector(
87
88
89
90
91
92
    num_logprobs: int,
    lower: int,
    upper: int,
    sampled_token_id: int,
    adjust_num_logprobs: bool = True,
) -> tuple[torch.Tensor, int]:
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    """Create a random vector of top logprob token indices

    Use to create fake sample logprobs for testing. The sampled token
    ID must always be one of the top logprobs, which this dummy test
    vector generator enforces. OpenAI API
    compatible engines must be able to return an additional sample
    logprob for the sampled token if the sampled token was not
    among the top sample logprobs; `adjust_num_logprobs` emulates
    this behavior by increasing the vector length by 1 if
    `adjust_num_logprobs` is set.

    Args:
      num_logprobs: number of top logprobs
      lower: lower range of token ids
      upper: upper range of token ids
      sampled_token_id: the token actually sampled
      adjust_num_logprobs: if True, emulate situation where sampled
                           token logprob must be injected into top
                           logprobs

    Returns:
      1D length-x torch Tensor of token ids where x is
      `num_logprobs+1` if `adjust_num_logprobs` and
      `num_logprobs` otherwise
      sampled_token_rank: the rank of sampled_token_id in the vocab
                          vector when sorted in descending order by
                          logprob
    """

    # Calculate the final number of logprobs required
    total_logprobs = num_logprobs + 1 if adjust_num_logprobs else num_logprobs

    # Generate random indices using torch
    choice_tensor = torch.randperm(upper - lower)[:total_logprobs] + lower

    # Ensure the sampled token ID is included in the tensor
    choice_tensor[0] = sampled_token_id

    # Check if the sampled_token_id occurs in choice_tensor[1:]
    if sampled_token_id in choice_tensor[1:]:
133
134
135
        sampled_token_rank = (
            (choice_tensor[1:] == sampled_token_id).nonzero(as_tuple=True)[0].item()
        )
136
137
138
139
140
141
142
143
    else:
        # If not found, assign a random int between num_logprobs and 50700
        sampled_token_rank = random.randint(num_logprobs, 50700)

    return choice_tensor, sampled_token_rank


def _create_random_top_token_test_matrix(
144
    shape: tuple[int, int],
145
146
    lower: int,
    upper: int,
147
148
    tokens_list: list[int],
) -> tuple[torch.Tensor, torch.Tensor]:
149
150
151
152
153
154
155
156
157
158
159
160
161
162
    """Create a random matrix of top logprob token indices

    Use to create fake prompt logprobs for testing.

    Token ids are generated randomly and sampled without
    replacement.

    Args:
      shape: (num_tokens, num_logprobs) tuple representing
             matrix shape
      lower: lower range of token ids
      upper: upper range of token ids

    Returns:
163
      tuple containing:
164
165
166
167
168
169
170
      - 2D num_tokens x num_logprobs+1 torch Tensor of token ids
      - 1D tensor of ranks of prompt tokens in their respective
        rows, or random values
    """
    num_elements = shape[0] * shape[1]
    choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower
    matrix = torch.cat(
171
172
173
174
175
176
        (
            torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1),
            choice_tensor.view(shape),
        ),
        dim=1,
    )
177
178
179
180
181
182
183

    # Initialize the tensor for storing the ranks
    prompt_token_ranks = torch.empty(shape[0], dtype=torch.int)

    # Iterate over each row to check presence of
    # tokens_list[rdx] and determine its index
    for rdx in range(shape[0]):
184
        row = matrix[rdx, 1:]  # Skip the first column as it contains the token list
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
        token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0]
        if token_index.numel() > 0:
            prompt_token_ranks[rdx] = token_index.item()
        else:
            prompt_token_ranks[rdx] = random.randint(shape[1], 50700)

    return matrix, prompt_token_ranks


def decode_token(
    tok_id: int,
    tokenizer: PreTrainedTokenizer,
) -> str:
    """Turn a single token id into its token string for test comparisons.

    Delegates to ``tokenizer.convert_ids_to_tokens``.

    Args:
      tok_id: token id to look up
      tokenizer: tokenizer that owns the vocabulary

    Returns:
      the token string for `tok_id`
    """
    token_str = tokenizer.convert_ids_to_tokens(tok_id)
    return token_str


def generate_dummy_sample_logprobs(
    sampled_tokens_list: list,
    num_logprobs: int,
    tokenizer: PreTrainedTokenizer,
) -> list[tuple[list[int], list[float], int]]:
    """Generate dummy sample logprobs.

    Imitates the per-position sample logprobs the engine core assembles
    during decode. For every sampled token it produces:
      - a vector of fake top token ids;
      - a vector of fake logprob values;
      - a fake rank for the sampled token.

    Args:
      sampled_tokens_list: sampled token ids, one per decode position
      num_logprobs: number of requested top logprobs; vectors are drawn with
                    `num_logprobs + 1` entries (assuming the vocab is larger)
      tokenizer: model tokenizer; only its vocabulary size is used, to bound
                 the range of fake token ids

    Returns:
      one ``(top_token_ids, logprobs, sampled_token_rank)`` tuple per sampled
      token, as plain Python lists/ints.
        - ``top_token_ids[0]`` is always the sampled token id.
        - Logprob values lie in ``[-100, 0)`` and are unsorted.
    """
    if not sampled_tokens_list:
        return []

    # On HF fast tokenizers `vocab` is a property that calls `get_vocab()`,
    # which may build a fresh dict on every access. Read the size once
    # instead of once per sampled token.
    vocab_upper = len(tokenizer.vocab) - 1

    res: list[tuple[list[int], list[float], int]] = []
    for sampled_token_id in sampled_tokens_list:
        # Draw token ids (and rank) before logprob values, so the torch and
        # `random` streams are consumed in the same order as before.
        token_ids, sampled_token_rank = _create_random_top_token_test_vector(
            num_logprobs, 0, vocab_upper, sampled_token_id
        )
        logprobs = _create_random_top_logprob_test_vector(num_logprobs + 1, -100, 0)
        res.append((token_ids.tolist(), logprobs.tolist(), sampled_token_rank))
    return res


def generate_dummy_prompt_logprobs_tensors(
    prompt_tokens_list: list,
    num_logprobs: int,
    tokenizer: PreTrainedTokenizer,
) -> LogprobsTensors:
    """Generate dummy prompt logprobs tensors.

    Imitates the prompt-logprob tensors the engine core builds when the whole
    prompt is prefilled in a single chunk.

    Args:
      prompt_tokens_list: prompt token ids
      num_logprobs: number of requested top logprobs per prompt position
      tokenizer: model tokenizer; only its vocabulary size is used, to bound
                 the range of fake token ids

    Returns:
      ``LogprobsTensors(token_ids, logprobs, ranks)``, where:
        - `token_ids` and `logprobs` are both
          ``(len(prompt_tokens_list) - 1) x (num_logprobs + 1)``;
        - column 0 of `token_ids` holds ``prompt_tokens_list[1:]``;
        - `ranks` has one entry per row.
    """
    # Row i of the prompt logprobs predicts prompt token i + 1. There are
    # therefore len(prompt) - 1 rows, and each row's "true" token comes from
    # prompt_tokens_list[1:], just as the engine would concatenate it. The
    # leading `None` for the first prompt token is injected later by the
    # detokenizer, not here.
    num_rows = len(prompt_tokens_list) - 1
    next_tokens = prompt_tokens_list[1:]

    # Token ids (and ranks) are drawn before logprob values.
    token_ids, ranks = _create_random_top_token_test_matrix(
        (num_rows, num_logprobs),
        0,
        len(tokenizer.vocab) - 1,
        next_tokens,
    )
    logprob_vals = _create_random_top_logprob_test_matrix(
        (num_rows, num_logprobs + 1), -100, 0
    )
    return LogprobsTensors(token_ids, logprob_vals, ranks)


@dataclass
class DummyOutputProcessorTestVectors:
    """Tokenizer, config and pre-built dummy data shared by the output
    processor tests."""

    tokenizer: GeneralTokenizerType
    vllm_config: EngineArgs
    # Per request: prompt tokens followed by generated tokens.
    full_tokens: list[list[int]]
    # Per request: prompt token ids.
    prompt_tokens: list[list[int]]
    # Per request: generated token ids.
    generation_tokens: list[list[int]]
    # Per request: one prompt-logprobs tensor triple
    # (top tokens, top logprobs, ranks).
    prompt_logprobs: list[LogprobsTensors]
    # Per request, per generated position: a
    # (top token ids, top logprobs, sampled-token rank) tuple.
    generation_logprobs: list[list[tuple[list[int], list[float], int]]]
    prompt_strings: list[str]
    # NOTE(review): presumably the length of each prompt string — confirm
    # against the code that fills this in.
    prompt_strings_len: list[int]
    generation_strings: list[str]


class MockEngineCore:
    """Mock engine core that emits outputs from premade token lists.

    Each call to `get_outputs` advances every unfinished request by one token
    and returns the matching `EngineCoreOutput` objects.
    """

    def __init__(
        self,
        tokens_list: list[list[int]],
        prompts_list: list[list[int]],
        # Per request, per sampled-token offset:
        # (list of top-k token ids, list of sample logprob values, rank)
        generated_logprobs_raw: list[list[tuple[list[int], list[float], int]]]
        | None = None,
        # Per request: a LogprobsTensors of prompt logprobs. It is attached
        # only to that request's first output.
        prompt_logprobs_raw: list[LogprobsTensors] | None = None,
        eos_token_id: int | None = None,
        stop_token_ids: list[int] | None = None,
        request_ids: list[str] | None = None,
    ) -> None:
        self.num_requests = len(tokens_list)
        self.tokens_list = tokens_list
        self.prompts_list = prompts_list

        self.generated_logprobs_raw = generated_logprobs_raw
        self.do_logprobs = generated_logprobs_raw is not None
        self.prompt_logprobs_raw = prompt_logprobs_raw
        self.do_prompt_logprobs = prompt_logprobs_raw is not None

        # Per-request progress: finished flag and index of the next token.
        self.request_finished = [False] * self.num_requests
        self.request_token_idx = [0] * self.num_requests

        self.eos_token_id = eos_token_id
        self.stop_token_ids = stop_token_ids

        if request_ids is None:
            request_ids = [f"request-{i}" for i in range(self.num_requests)]
        self.request_ids = request_ids

    def get_outputs(self, num_active: int = -1) -> list[EngineCoreOutput]:
        """Emit the next token of every unfinished request.

        Args:
          num_active: unless -1, stop before the request at index
                      `num_active`, so only the first `num_active` requests
                      are considered

        Returns:
          one EngineCoreOutput per considered request that was unfinished
        """
        outputs: list[EngineCoreOutput] = []
        pairs = zip(self.tokens_list, self.prompts_list)
        for req_idx, (token_ids, prompt_token_ids) in enumerate(pairs):
            if num_active != -1 and req_idx >= num_active:
                break
            if self.request_finished[req_idx]:
                continue
            outputs.append(self._next_output(req_idx, token_ids, prompt_token_ids))
        return outputs

    def _next_output(
        self,
        req_idx: int,
        token_ids: list[int],
        prompt_token_ids: list[int],
    ) -> EngineCoreOutput:
        """Build request `req_idx`'s output for its next token and advance it."""
        token_idx = self.request_token_idx[req_idx]
        is_first = token_idx == 0

        sample_logprobs = None
        if self.do_logprobs:
            assert self.generated_logprobs_raw is not None
            top_ids, top_vals, rank = self.generated_logprobs_raw[req_idx][token_idx]
            sample_logprobs = LogprobsLists(
                np.array([top_ids]),
                np.array([top_vals]),
                np.array([rank]),
            )

        # Prompt logprobs are only attached to the first (prefill) output.
        prompt_logprobs = None
        if self.do_prompt_logprobs and is_first:
            assert self.prompt_logprobs_raw is not None
            prompt_logprobs = self.prompt_logprobs_raw[req_idx]

        # Prefill stats are likewise reported once, on the first output.
        prefill_stats = None
        if is_first:
            prefill_stats = PrefillStats()
            prefill_stats.set(
                num_prompt_tokens=len(prompt_token_ids),
                num_local_cached_tokens=0,
                num_external_cached_tokens=0,
            )

        new_token_id = token_ids[token_idx]
        output = EngineCoreOutput(
            request_id=self.request_ids[req_idx],
            new_token_ids=[new_token_id],
            new_logprobs=sample_logprobs,
            new_prompt_logprobs_tensors=prompt_logprobs,
            prefill_stats=prefill_stats,
        )

        # Checks run in order and later ones win: EOS or a stop token on the
        # final position reports STOP rather than LENGTH.
        if token_idx == len(token_ids) - 1:
            output.finish_reason = FinishReason.LENGTH
            self.request_finished[req_idx] = True
        if new_token_id == self.eos_token_id:
            output.finish_reason = FinishReason.STOP
            self.request_finished[req_idx] = True
        if new_token_id in (self.stop_token_ids or ()):
            output.finish_reason = FinishReason.STOP
            output.stop_reason = new_token_id
            self.request_finished[req_idx] = True

        self.request_token_idx[req_idx] += 1
        return output