# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

5
6
7
import hashlib
import importlib.metadata
import os
8
import tempfile
9
10
from typing import TYPE_CHECKING

11
import numpy as np
12
import regex as re
13
import torch
14
15
16
17
18
from cachetools import LRUCache
from diskcache import Cache

import vllm.envs as envs
from vllm.logger import init_logger
19
from vllm.utils.import_utils import LazyLoader
20
from vllm.utils.platform_utils import is_pin_memory_available
21
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
22
23
24

if TYPE_CHECKING:
    # Real imports, seen only by static type checkers.
    import outlines_core as oc
    import transformers.convert_slow_tokenizer as convert_slow_tokenizer
    import transformers.file_utils as file_utils
    import xgrammar as xgr

    from vllm.tokenizers import TokenizerLike
    from vllm.v1.worker.gpu_input_batch import InputBatch
else:
    # At runtime the same module names are bound to LazyLoader objects
    # instead of being imported directly here.
    xgr = LazyLoader("xgr", globals(), "xgrammar")
    oc = LazyLoader("oc", globals(), "outlines_core")
    file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
    convert_slow_tokenizer = LazyLoader(
        "convert_slow_tokenizer", globals(), "transformers.convert_slow_tokenizer"
    )

39

40
41
42
43
44
# Module-level logger, named after this module.
logger = init_logger(__name__)

# NOTE(review): initialised to None and not referenced by any definition
# visible in this portion of the file - confirm whether it is still needed.
CACHE = None


45
46
def apply_grammar_bitmask(
    scheduler_output: SchedulerOutput,
    grammar_output: GrammarOutput,
    input_batch: InputBatch,
    logits: torch.Tensor,
) -> None:
    """
    Apply grammar bitmask to output logits of the model with xgrammar function.

    The logits tensor is modified in place; nothing is returned.

    Args:
        scheduler_output (SchedulerOutput): The result of engine scheduling.
            Only ``scheduled_spec_decode_tokens`` is read here.
        grammar_output (GrammarOutput): Supplies ``grammar_bitmask`` (the
            compacted bitmask rows) and ``structured_output_request_ids``
            (the request ids those rows belong to, in row order).
        input_batch (InputBatch): The input of model runner. Only
            ``req_ids`` is read here.
        logits (torch.Tensor): The output logits of model forward.
    """
    # Serialization of np.ndarray is much more efficient than a tensor,
    # so we receive it in that format.
    grammar_bitmask = grammar_output.grammar_bitmask

    # We receive the structured output bitmask from the scheduler,
    # compacted to contain bitmasks only for structured output requests.
    # The order of the requests in the bitmask is not guaranteed to be the
    # same as the order of the requests in the gpu runner's batch. We need
    # to sort the bitmask to match the order of the requests used here.

    # Get the batch indices of the structured output requests.
    # Keep track of the number of speculative tokens scheduled for every
    # request in the batch, as the logit indices are offset by this amount.
    struct_out_req_batch_indices: dict[str, int] = {}
    cumulative_offset = 0
    spec_tokens = scheduler_output.scheduled_spec_decode_tokens
    struct_out_req_ids = set(grammar_output.structured_output_request_ids)
    for batch_index, req_id in enumerate(input_batch.req_ids):
        logit_index = batch_index + cumulative_offset
        cumulative_offset += len(spec_tokens.get(req_id, ()))
        if req_id in struct_out_req_ids:
            struct_out_req_batch_indices[req_id] = logit_index

    out_indices: list[int] = []

    # Reorder the bitmask to match the order of the requests in the batch.
    # Rows that belong to no structured output request keep the -1 fill.
    sorted_bitmask = np.full(
        shape=(logits.shape[0], grammar_bitmask.shape[1]),
        fill_value=-1,
        dtype=grammar_bitmask.dtype,
    )
    cumulative_index = 0
    for req_id in grammar_output.structured_output_request_ids:
        num_spec_tokens = len(spec_tokens.get(req_id, ()))
        if (logit_idx := struct_out_req_batch_indices.get(req_id)) is not None:
            for i in range(1 + num_spec_tokens):
                bitmask_index = logit_idx + i
                sorted_bitmask[bitmask_index] = grammar_bitmask[cumulative_index + i]
                out_indices.append(bitmask_index)
        # Advance past this request's rows even when it is not in the batch,
        # so later requests keep reading from the right source rows.
        cumulative_index += 1 + num_spec_tokens

    # Copy async to device as tensor.
    grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
        logits.device, non_blocking=True
    )

    # If the length of out indices and the logits have the same shape
    # we don't need to pass indices to the kernel,
    # since the bitmask is already aligned with the logits.
    skip_out_indices = len(out_indices) == logits.shape[0]

    if not logits.is_cpu:
        index_tensor = None
        if not skip_out_indices:
            # xgrammar expects a python list of indices but it will actually work with
            # a tensor. If we copy the tensor ourselves here we can do it in a
            # non_blocking manner and there should be no cpu sync within xgrammar.
            pin_memory = is_pin_memory_available()
            index_tensor = torch.tensor(
                out_indices, dtype=torch.int32, device="cpu", pin_memory=pin_memory
            )
            index_tensor = index_tensor.to(logits.device, non_blocking=True)

        xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
        return

    # CPU case, use list for indices.
    indices = None if skip_out_indices else out_indices
    # Handle dtype conversion for CPU (older xgrammar CPU kernels require float32)
    # See: https://github.com/vllm-project/vllm/issues/31901
    if logits.dtype != torch.float32:
        # Convert to float32, apply bitmask, then convert back
        logits_fp32 = logits.to(torch.float32)
        xgr.apply_token_bitmask_inplace(logits_fp32, grammar_bitmask, indices=indices)
        # Copy the modified values back to the original tensor
        logits.copy_(logits_fp32.to(logits.dtype))
    else:
        xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=indices)
137
138


139
140
141
142
143
144
145
146
147
148
149
class OutlinesVocabulary:
    """
    Wrapper class for `outlines_core.Vocabulary`,
    which allows us to store a hash with the vocabulary
    """

    def __init__(self, vocabulary: oc.Vocabulary) -> None:
        # Actual vocabulary object
        self.inner = vocabulary
        # The hash is used as a cache key. It is the SHA-256 digest of the
        # vocabulary's repr read as a (non-negative) integer, so it depends
        # only on that repr text.
        digest = hashlib.sha256(repr(vocabulary).encode("utf-8")).hexdigest()
        self._hash = int(digest, 16)


def get_outlines_cache_path() -> str:
    """Return the directory path to use for the outlines cache.

    The path is only computed here; nothing is created on disk.

    Resolution order:
        1. ``$OUTLINES_CACHE_DIR``, returned as-is.
        2. ``$XDG_CACHE_HOME/.cache/outlines``.
        3. ``~/.cache/outlines``, when the home directory exists and is
           not ``/``.
        4. ``<system temp dir>/.cache/outlines``.

    Returns:
        str: The selected cache directory path.
    """
    outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR")
    if outlines_cache_dir:
        # OUTLINES_CACHE_DIR takes precedence
        return outlines_cache_dir

    xdg_cache_home = os.getenv("XDG_CACHE_HOME")
    if xdg_cache_home:
        # NOTE(review): a ".cache" segment is appended under XDG_CACHE_HOME.
        # Kept unchanged so existing cache locations stay valid - confirm
        # this layout is intended.
        return os.path.join(xdg_cache_home, ".cache", "outlines")

    # If homedir is "/", we may be inside a container, and thus writing to
    # root would be problematic, so we fall back to using a tempfile.
    # Also validate the path exists, since os.path.expanduser does
    # not guarantee existence.
    home_dir = os.path.expanduser("~")
    if os.path.isdir(home_dir) and home_dir != "/":
        # Default Unix fallback: ~/.cache/outlines
        return os.path.join(home_dir, ".cache", "outlines")

    # home_dir may be / inside a docker container without existing user
    tempdir = tempfile.gettempdir()
    return os.path.join(tempdir, ".cache", "outlines")
177
178
179
180
181
182
183


def get_outlines_cache():
    """Get the Cache instance to be used for index caching

    Returns:
        A ``diskcache.Cache`` rooted at ``get_outlines_cache_path()`` when
        ``envs.VLLM_V1_USE_OUTLINES_CACHE`` is truthy; otherwise an
        in-memory ``LRUCache`` holding at most 128 entries.
    """
    if not envs.VLLM_V1_USE_OUTLINES_CACHE:
        return LRUCache(maxsize=128)

    logger.warning(
        "Enabling outlines cache. This is an unbounded on-disk "
        "cache. It may consume a lot of disk space and should "
        "not be used with untrusted clients."
    )
    # The path is only needed for the on-disk cache, so resolve it here.
    cache = Cache(get_outlines_cache_path(), eviction_policy="none", cull_limit=0)
    outlines_version = importlib.metadata.version("outlines_core")

    # Entries written under a different outlines_core version are dropped
    # before the current version is recorded.
    cached_version = cache.get("__version__", None)
    if cached_version != outlines_version:
        cache.clear()
    cache.set("__version__", outlines_version)
    return cache
199
200
201
202
203
204


# Matches a token spelled exactly "<0xXX>", where XX is two uppercase
# hexadecimal digits.
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
# Matches a string made of at most 6 characters, then one or more U+FFFD
# replacement characters, then at most 6 more characters.
re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$")


205
def _reduced_vocabulary(tokenizer: TokenizerLike) -> dict[bytes, list[int]]:
    """Create a map from vocabulary tokens to lists of equivalent token ids.

    Special tokens, tokens whose string form is empty, and the EOS token id
    are left out of the mapping.

    Args:
        tokenizer: Tokenizer exposing ``eos_token_id``, ``all_special_tokens``,
            ``get_vocab()`` and ``convert_tokens_to_string()``.

    Returns:
        A dict of token bytes -> equivalent token ids

    Raises:
        RuntimeError: If a token containing invalid UTF-8 is neither a
            ``<0xXX>`` byte token nor fully mappable through the GPT-2
            unicode-to-byte table.
    """
    eos_token_id = tokenizer.eos_token_id
    # Build the lookup set once so the per-token membership test does not
    # re-read and scan ``all_special_tokens`` on every iteration.
    special_tokens = set(tokenizer.all_special_tokens)

    unicode_to_bytes = {
        v: k for k, v in convert_slow_tokenizer.bytes_to_unicode().items()
    }

    def convert_token_to_string(token: str) -> str:
        string = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
        if (
            isinstance(token, str)
            and token.startswith(file_utils.SPIECE_UNDERLINE)
        ) or token == "<0x20>":
            return " " + string

        return string

    vocabulary: dict[bytes, list[int]] = {}
    for token, token_idx in tokenizer.get_vocab().items():
        if token in special_tokens:
            continue

        token_str = convert_token_to_string(token)
        if not token_str:
            # Tokens that convert to an empty string are not included.
            continue

        if isinstance(token, (bytes, bytearray)):
            # For BPE tokenizers where tokens are stored as bytes.

            # safe to ignore since token_str is of type (bytearray, bytes)
            # by this point.
            token_bytes = bytes(token_str)  # type: ignore[arg-type]
        elif (token_str == "\ufffd" and token != "\ufffd") or (
            "\ufffd" in token_str and not re_replacement_seq.match(token_str)
        ):
            # Handle tokens with invalid UTF-8 sequences.
            if re_llama_byte_token.match(token):
                # Llama-like tokenizers use <0xXX> for incomplete sequences.
                token_bytes = bytes([int(token[3:5], 16)])
            else:
                # GPT2 tokenizers: map each byte back using unicode_to_bytes
                byte_vals = [unicode_to_bytes.get(c) for c in token]
                if None in byte_vals:
                    raise RuntimeError(
                        f"Cannot convert token `{token}`"
                        f" ({token_idx}) to bytes: {token_str}"
                    )
                # safe to ignore, since if None in byte_vals,
                # an error is thrown.
                token_bytes = bytes(byte_vals)  # type: ignore[arg-type]
        else:
            token_bytes = token_str.encode("utf-8")

        # The EOS id is converted like any other token (so conversion errors
        # still surface) but is not recorded in the mapping.
        if token_idx != eos_token_id:
            vocabulary.setdefault(token_bytes, []).append(token_idx)

    return vocabulary


274
def get_outlines_vocabulary(tokenizer: TokenizerLike) -> OutlinesVocabulary:
    """Get the `OutlinesVocabulary` wrapper for a given tokenizer.

    The result is memoised on the tokenizer object itself, in its
    ``_outlines_vocabulary`` attribute, so later calls with the same
    tokenizer return the stored object.

    Args:
        tokenizer: Tokenizer to build the vocabulary from.

    Returns:
        The ``OutlinesVocabulary`` wrapping an ``outlines_core`` vocabulary
        built from the tokenizer's EOS token id and reduced vocabulary.
    """
    if hasattr(tokenizer, "_outlines_vocabulary"):
        return tokenizer._outlines_vocabulary  # type: ignore

    reduced_vocab = _reduced_vocabulary(tokenizer)
    vocabulary = OutlinesVocabulary(
        oc.Vocabulary(tokenizer.eos_token_id, reduced_vocab)
    )
    tokenizer._outlines_vocabulary = vocabulary  # type: ignore

    return vocabulary
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306


def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Check if grammar appears to use Lark syntax.

    This is a heuristic: after ``#`` and ``//`` comments are stripped, the
    grammar is treated as Lark unless some line contains the EBNF rule
    operator ``::=``.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise
        (including for empty or non-string input)

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not grammar_str or not isinstance(grammar_str, str):
        return False

    # Remove both comment styles from each line, then look for an EBNF
    # rule definition in what is left.
    return not any(
        "::=" in re.sub(r"(#|//).*$", "", line)
        for line in grammar_str.split("\n")
    )


def convert_lark_to_ebnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to EBNF format.

    EBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    The output starts with a ``root ::= <rule>`` line, where ``<rule>`` is
    the rule named ``start`` if one is defined and the first defined rule
    otherwise. Single-quoted strings are rewritten with double quotes, and
    ``|`` continuation lines are folded into the preceding rule.

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in EBNF format

    Raises:
        ValueError: If the input is not a string, is empty, defines no
            rules, has a rule with an empty name, has mismatched quotes,
            has a ``|`` alternative before any rule, or references a rule
            that is never defined.

    Examples:
        >>> print(convert_lark_to_ebnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules: set[str] = set()
    referenced_rules: set[str] = set()
    output_lines: list[str] = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r"(#|//).*$", "", line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set[str]:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', "", text)
        text = re.sub(r"[+*?()|\[\]{}]", " ", text)
        return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", text))

    # First pass: Find root rule and validate rule definitions
    lines = [clean_line(line) for line in grammar_str.split("\n")]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        if not line or line.startswith("|") or ":" not in line:
            continue

        name = line.split(":", 1)[0].strip().strip("?")
        if not name:
            # A line such as ": 'x'" has nothing before the colon. Without
            # this check it would yield an empty "root ::= " line.
            raise ValueError(
                f"Invalid rule format on line {line_num}. "
                "Expected 'rule_name: definition'"
            )
        defined_rules.add(name)
        # The first rule is the root unless a rule named "start" exists.
        if first_rule is None or name == "start":
            first_rule = name

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition: list[str] = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ":" in line and not line.startswith("|"):
                # Save previous rule if exists
                if current_rule is not None:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}"
                    )

                # Process new rule
                name, definition = line.split(":", 1)
                current_rule = name.strip().strip("?")

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                # Quotes are normalised before references are extracted,
                # because extract_references only strips double-quoted text.
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith("|"):
                if current_rule is None:
                    raise ValueError(
                        f"Alternative '|' on line {line_num} "
                        "without a preceding rule definition"
                    )

                alt_def = line[1:].strip()
                check_quotes(
                    alt_def, f"alternative for rule '{current_rule}'", line_num
                )
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            # Prefix every per-line failure with its line number.
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule is not None:
        output_lines.append(f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {"root"}
    if undefined_rules:
        raise ValueError(
            f"Referenced rules are not defined: {', '.join(sorted(undefined_rules))}"
        )

    return "\n".join(output_lines)
448
449
450
451
452
453


def choice_as_grammar(choice: list[str]) -> str:
    """Build an EBNF grammar whose root rule matches one of the choices.

    Args:
        choice: The literal strings to accept.

    Returns:
        str: A single ``root ::= "a" | "b" | ...`` rule, with double quotes
        and backslashes inside each choice backslash-escaped.
    """

    def escape_ebnf_string(s: str) -> str:
        """Escape special characters in a EBNF string."""
        # Escape double quotes and backslashes
        return re.sub(r'(["\\])', r"\\\1", s)

    quoted_choices = (f'"{escape_ebnf_string(c)}"' for c in choice)
    return "root ::= " + " | ".join(quoted_choices)