"tests/kernels/attention/test_attention.py" did not exist on "03ffd0a02251e10c1aa14fca8cb0ab1e4e40b886"
utils.py 15.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from __future__ import annotations

5
6
7
import hashlib
import importlib.metadata
import os
8
import tempfile
9
10
from typing import TYPE_CHECKING

11
import numpy as np
12
import regex as re
13
import torch
14
15
16
17
18
from cachetools import LRUCache
from diskcache import Cache

import vllm.envs as envs
from vllm.logger import init_logger
19
from vllm.utils.import_utils import LazyLoader
20
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
21
22
23

if TYPE_CHECKING:
    import outlines_core as oc
Harry Mellor's avatar
Harry Mellor committed
24
    import transformers.convert_slow_tokenizer as convert_slow_tokenizer
25
    import transformers.file_utils as file_utils
26
    import xgrammar as xgr
27

28
    from vllm.tokenizers import TokenizerLike
29
    from vllm.v1.worker.gpu_input_batch import InputBatch
30
else:
31
    xgr = LazyLoader("xgr", globals(), "xgrammar")
32
33
    oc = LazyLoader("oc", globals(), "outlines_core")
    file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
Harry Mellor's avatar
Harry Mellor committed
34
35
    convert_slow_tokenizer = LazyLoader(
        "convert_slow_tokenizer", globals(), "transformers.convert_slow_tokenizer"
36
37
    )

38

39
40
41
42
43
# NOTE(review): CACHE is not read or written anywhere in the code visible
# here; it may be vestigial or referenced from another module — confirm
# before removing.
CACHE = None

# Module logger, named after this module.
logger = init_logger(__name__)


44
45
def apply_grammar_bitmask(
    scheduler_output: SchedulerOutput,
    grammar_output: GrammarOutput,
    input_batch: InputBatch,
    logits: torch.Tensor,
) -> None:
    """
    Apply grammar bitmask to output logits of the model with xgrammar function.

    The logits are modified in place via ``xgr.apply_token_bitmask_inplace``.

    Args:
        scheduler_output (SchedulerOutput): The result of engine scheduling.
            Its ``scheduled_spec_decode_tokens`` determines how many extra
            (speculative) logit rows each request occupies.
        grammar_output (GrammarOutput): The compacted ``grammar_bitmask``
            together with ``structured_output_request_ids``, the requests it
            covers, listed in bitmask row order.
        input_batch (InputBatch): The input of model runner.
        logits (torch.Tensor): The output logits of model forward.
    """
    # Serialization of np.ndarray is much more efficient than a tensor,
    # so we receive it in that format.
    grammar_bitmask = grammar_output.grammar_bitmask

    # We receive the structured output bitmask from the scheduler,
    # compacted to contain bitmasks only for structured output requests.
    # The order of the requests in the bitmask is not guaranteed to be the
    # same as the order of the requests in the gpu runner's batch. We need
    # to sort the bitmask to match the order of the requests used here.

    # Get the batch indices of the structured output requests.
    # Keep track of the number of speculative tokens scheduled for every
    # request in the batch, as the logit indices are offset by this amount.
    struct_out_req_batch_indices: dict[str, int] = {}
    cumulative_offset = 0
    spec_tokens = scheduler_output.scheduled_spec_decode_tokens
    struct_out_req_ids = set(grammar_output.structured_output_request_ids)
    for batch_index, req_id in enumerate(input_batch.req_ids):
        logit_index = batch_index + cumulative_offset
        cumulative_offset += len(spec_tokens.get(req_id, ()))
        if req_id in struct_out_req_ids:
            struct_out_req_batch_indices[req_id] = logit_index

    out_indices: list[int] = []

    # Reorder the bitmask to match the order of the requests in the batch.
    sorted_bitmask = np.full(
        shape=(logits.shape[0], grammar_bitmask.shape[1]),
        fill_value=-1,
        dtype=grammar_bitmask.dtype,
    )
    cumulative_index = 0
    for req_id in grammar_output.structured_output_request_ids:
        # One bitmask row for the regular next token plus one per
        # speculative token scheduled for this request.
        num_rows = 1 + len(spec_tokens.get(req_id, ()))
        logit_idx = struct_out_req_batch_indices.get(req_id)
        if logit_idx is not None:
            row_offsets = np.arange(num_rows)
            # Copy all rows of this request in one fancy-indexed assignment.
            # Integer-array indexing (unlike slicing) still raises IndexError
            # on out-of-range rows instead of silently truncating.
            sorted_bitmask[logit_idx + row_offsets] = grammar_bitmask[
                cumulative_index + row_offsets
            ]
            out_indices.extend(range(logit_idx, logit_idx + num_rows))
        cumulative_index += num_rows

    # Copy async to device as tensor.
    grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
        logits.device, non_blocking=True
    )

    # If the length of out indices and the logits have the same shape
    # we don't need to pass indices to the kernel,
    # since the bitmask is already aligned with the logits.
    skip_out_indices = len(out_indices) == logits.shape[0]

    index_tensor = None
    if not skip_out_indices:
        # xgrammar expects a python list of indices but it will actually work with
        # a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
        # manner and there should be no cpu sync within xgrammar.
        index_tensor = torch.tensor(
            out_indices, dtype=torch.int32, device="cpu", pin_memory=True
        )
        index_tensor = index_tensor.to(logits.device, non_blocking=True)

    xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
120
121


122
123
124
125
126
127
128
129
130
131
132
class OutlinesVocabulary:
    """
    Wrapper class for `outlines_core.Vocabulary`,
    which allows us to store a hash with the vocabulary.

    Attributes:
        inner: The wrapped ``outlines_core.Vocabulary``.
        _hash: Non-negative integer derived from the SHA-256 digest of
            ``repr(inner)``; usable as a cache key.
    """

    def __init__(self, vocabulary: oc.Vocabulary) -> None:
        # Actual vocabulary object
        self.inner = vocabulary
        # Derive a stable, non-negative cache key from the vocabulary's repr.
        # The built-in hash() can be negative and, for str, is salted per
        # process, so a SHA-256 digest interpreted as an int is used instead.
        hex_str = hashlib.sha256(repr(vocabulary).encode("utf-8")).hexdigest()
        self._hash = int(hex_str, 16)


def get_outlines_cache_path() -> str:
    """Return the directory path used for the on-disk outlines cache.

    Resolution order:

    1. ``$OUTLINES_CACHE_DIR``, used verbatim.
    2. ``$XDG_CACHE_HOME/.cache/outlines``.
    3. ``~/.cache/outlines`` when the home directory exists and is not ``/``.
    4. ``<system temp dir>/.cache/outlines`` otherwise.

    Empty environment variables are treated as unset. This function only
    computes the path; it does not create the directory.
    """
    outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR")
    if outlines_cache_dir:
        # OUTLINES_CACHE_DIR takes precedence
        return outlines_cache_dir

    xdg_cache_home = os.getenv("XDG_CACHE_HOME")
    if xdg_cache_home:
        # NOTE(review): per the XDG Base Directory spec, $XDG_CACHE_HOME is
        # already the cache root (e.g. ~/.cache), so the extra ".cache"
        # segment looks redundant. Kept as-is so existing caches do not
        # move — confirm before changing.
        return os.path.join(xdg_cache_home, ".cache", "outlines")

    # If homedir is "/", we may be inside a container, and thus writing to
    # root would be problematic, so we fall back to using a tempfile.
    # Also validate the path exists, since os.path.expanduser does
    # not guarantee existence.
    home_dir = os.path.expanduser("~")
    if os.path.isdir(home_dir) and home_dir != "/":
        # Default Unix fallback: ~/.cache/outlines
        return os.path.join(home_dir, ".cache", "outlines")

    # home_dir may be / inside a docker container without existing user
    return os.path.join(tempfile.gettempdir(), ".cache", "outlines")
160
161
162
163
164
165
166


def get_outlines_cache() -> Cache | LRUCache:
    """Get the Cache instance to be used for index caching.

    When ``envs.VLLM_V1_USE_OUTLINES_CACHE`` is enabled, returns an
    unbounded on-disk ``diskcache.Cache`` rooted at
    ``get_outlines_cache_path()``. Its contents are cleared whenever the
    installed ``outlines_core`` version differs from the version recorded in
    the cache under ``"__version__"``. Otherwise returns an in-memory
    ``LRUCache`` holding at most 128 entries.
    """
    if not envs.VLLM_V1_USE_OUTLINES_CACHE:
        return LRUCache(maxsize=128)

    logger.warning(
        "Enabling outlines cache. This is an unbounded on-disk "
        "cache. It may consume a lot of disk space and should "
        "not be used with untrusted clients."
    )
    cache = Cache(get_outlines_cache_path(), eviction_policy="none", cull_limit=0)
    outlines_version = importlib.metadata.version("outlines_core")

    # Entries written by a different outlines_core version are discarded and
    # the current version is recorded. When the version already matches,
    # the cache is left untouched (no redundant disk write).
    if cache.get("__version__", None) != outlines_version:
        cache.clear()
        cache.set("__version__", outlines_version)
    return cache
182
183
184
185
186
187


# Matches Llama/SentencePiece byte-fallback tokens such as "<0x0A>"
# (uppercase hex digits only).
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
# Matches a run of U+FFFD replacement characters with at most 6 arbitrary
# characters on either side. The character is spelled as a \ufffd escape
# rather than a literal so the pattern survives re-encoding of this file.
re_replacement_seq = re.compile(r"^.{0,6}\ufffd+.{0,6}$")


188
def _reduced_vocabulary(tokenizer: TokenizerLike) -> dict[bytes, list[int]]:
    """Create a map from vocabulary tokens to lists of equivalent token ids.

    Special tokens, tokens that decode to an empty string, and the EOS token
    are left out. Several token ids may decode to the same bytes, hence the
    list values.

    Args:
        tokenizer: Tokenizer whose ``get_vocab()`` is reduced.

    Returns:
        A Dict of token bytes -> equivalent token ids

    Raises:
        RuntimeError: If a token whose decoded string contains invalid UTF-8
            (U+FFFD) cannot be mapped back to raw bytes.
    """
    eos_token_id = tokenizer.eos_token_id
    # Read the special tokens once into a set. Checking
    # ``tokenizer.all_special_tokens`` per vocab entry repeats whatever work
    # the tokenizer does to build it and is a linear scan if it is a list.
    special_tokens = set(tokenizer.all_special_tokens)

    unicode_to_bytes = {
        v: k for k, v in convert_slow_tokenizer.bytes_to_unicode().items()
    }

    def convert_token_to_string(token: str) -> str:
        string = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
        if (
            type(token) is str and token.startswith(file_utils.SPIECE_UNDERLINE)
        ) or token == "<0x20>":
            return " " + string

        return string

    vocabulary: dict[bytes, list[int]] = {}
    for token, token_idx in tokenizer.get_vocab().items():
        if token in special_tokens:
            continue

        token_str = convert_token_to_string(token)
        if not token_str:
            # Tokens that decode to nothing have no byte representation.
            continue

        if isinstance(token, (bytes, bytearray)):
            # For BPE tokenizers where tokens are stored as bytes.

            # safe to ignore since token_str is of type (bytearray, bytes)
            # by this point.
            token_bytes = bytes(token_str)  # type: ignore[arg-type]

        elif (token_str == "\ufffd" and token != "\ufffd") or (
            "\ufffd" in token_str and not re_replacement_seq.match(token_str)
        ):
            # Handle tokens with invalid UTF-8 sequences.
            if re_llama_byte_token.match(token):
                # Llama-like tokenizers use <0xXX> for incomplete sequences.
                token_bytes = bytes([int(token[3:5], 16)])
            else:
                # GPT2 tokenizers: map each byte back using unicode_to_bytes
                byte_vals = [unicode_to_bytes.get(c) for c in token]
                if None in byte_vals:
                    raise RuntimeError(
                        f"Cannot convert token `{token}`"
                        f" ({token_idx}) to bytes: {token_str}"
                    )
                # safe to ignore, since if None in byte_vals,
                # an error is thrown.
                token_bytes = bytes(byte_vals)  # type: ignore[arg-type]
        else:
            token_bytes = token_str.encode("utf-8")

        # The EOS check comes after conversion so conversion errors are
        # still raised for it, matching the original ordering.
        if token_idx != eos_token_id:
            vocabulary.setdefault(token_bytes, []).append(token_idx)

    return vocabulary


257
def get_outlines_vocabulary(tokenizer: TokenizerLike) -> OutlinesVocabulary:
    """Get the `OutlinesVocabulary` wrapper for a given tokenizer.

    The result is memoized on the tokenizer object itself under the
    ``_outlines_vocabulary`` attribute. If that attribute is already present,
    its value is returned as-is. Otherwise the vocabulary is reduced, wrapped
    in ``oc.Vocabulary`` (keyed by the tokenizer's EOS id) and
    ``OutlinesVocabulary``, stored on the tokenizer, and returned.

    Args:
        tokenizer: Tokenizer to build (or fetch) the vocabulary for.

    Returns:
        The cached or newly built ``OutlinesVocabulary``.
    """
    if hasattr(tokenizer, "_outlines_vocabulary"):
        return tokenizer._outlines_vocabulary  # type: ignore

    reduced_vocab = _reduced_vocabulary(tokenizer)
    vocabulary = OutlinesVocabulary(
        oc.Vocabulary(tokenizer.eos_token_id, reduced_vocab)
    )
    tokenizer._outlines_vocabulary = vocabulary  # type: ignore

    return vocabulary
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289


def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Heuristically decide whether a grammar string is written in Lark syntax.

    Each line has its ``#`` / ``//`` comment removed. The grammar is treated
    as EBNF (not Lark) as soon as any remaining line contains ``::=``.

    Args:
        grammar_str: Grammar text to inspect.

    Returns:
        bool: False for empty or non-string input, or when an EBNF ``::=``
        definition is found; True otherwise.

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not (grammar_str and isinstance(grammar_str, str)):
        return False

    # Comment-free, whitespace-trimmed version of every line.
    uncommented = (
        re.sub(r"(#|//).*$", "", raw_line).strip()
        for raw_line in grammar_str.split("\n")
    )
    # Any EBNF-style rule definition rules out Lark.
    return not any("::=" in text for text in uncommented)


def convert_lark_to_ebnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to EBNF format.

    EBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    The conversion is line based. Every ``name: definition`` line starts a
    new rule, and every line starting with ``|`` adds an alternative to the
    most recent rule. Single-quoted literals are rewritten as double-quoted
    ones. The output begins with ``root ::= <start>``, where ``<start>`` is
    the ``start`` rule if one is defined, otherwise the first rule.

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in EBNF format

    Raises:
        ValueError: If the input is not a non-empty string, defines no
            rules, has mismatched quotes, has a ``|`` alternative with no
            preceding rule, or references rules that are never defined.

    Examples:
        >>> print(convert_lark_to_ebnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove a trailing ``#`` / ``//`` comment and surrounding whitespace.

        Comment markers inside single- or double-quoted literals are kept,
        so terminals such as ``"#"`` or ``"http://x"`` survive. Quotes are
        tracked naively (no backslash escapes), consistent with
        ``check_quotes``. Markers inside Lark ``/regex/`` literals are still
        treated as comments.
        """
        open_quote = None
        for i, ch in enumerate(line):
            if open_quote is not None:
                if ch == open_quote:
                    open_quote = None
            elif ch in ("'", '"'):
                open_quote = ch
            elif ch == "#" or line.startswith("//", i):
                return line[:i].strip()
        return line.strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set[str]:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', "", text)
        text = re.sub(r"[+*?()|\[\]{}]", " ", text)
        return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", text))

    # First pass: Find root rule and collect rule definitions
    lines = [clean_line(line) for line in grammar_str.split("\n")]
    first_rule = None

    for line in lines:
        if not line or line.startswith("|"):
            continue

        if ":" in line:
            # Strip Lark's "?" rule-name prefix.
            name = line.split(":", 1)[0].strip().strip("?")
            defined_rules.add(name)
            if first_rule is None:
                first_rule = name
            # An explicit "start" rule wins over the first rule.
            if name == "start":
                first_rule = "start"

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    # NOTE(review): lines that neither contain ":" nor start with "|" (e.g.
    # multi-line continuations) are silently ignored here.
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ":" in line and not line.startswith("|"):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}"
                    )

                # Process new rule
                name, definition = line.split(":", 1)
                current_rule = name.strip().strip("?")

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith("|"):
                if not current_rule:
                    raise ValueError(
                        f"Alternative '|' on line {line_num} "
                        "without a preceding rule definition"
                    )

                alt_def = line[1:].strip()
                check_quotes(
                    alt_def, f"alternative for rule '{current_rule}'", line_num
                )
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {"root"}
    if undefined_rules:
        raise ValueError(
            f"Referenced rules are not defined: {', '.join(sorted(undefined_rules))}"
        )

    return "\n".join(output_lines)
431
432
433
434
435
436


def choice_as_grammar(choice: list[str]) -> str:
    """Build an EBNF grammar whose ``root`` rule matches one of *choice*.

    Every option becomes a double-quoted EBNF string literal, with embedded
    double quotes and backslashes escaped by a preceding backslash. Options
    are joined with ``|`` in the given order.

    Example:
        >>> choice_as_grammar(["yes", "no"])
        'root ::= "yes" | "no"'
    """
    alternatives = []
    for option in choice:
        # Escape '"' and '\' so each option stays a single EBNF literal.
        escaped = re.sub(r'(["\\])', r"\\\1", option)
        alternatives.append(f'"{escaped}"')
    return "root ::= " + " | ".join(alternatives)