backend_outlines.py 11.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import ast
import importlib
import json
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import torch
from regex import escape as regex_escape

from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
18
19
20
21
22
23
24
25
26
27
from vllm.v1.structured_output.backend_types import (
    StructuredOutputBackend,
    StructuredOutputGrammar,
    StructuredOutputOptions,
)
from vllm.v1.structured_output.utils import (
    OutlinesVocabulary,
    get_outlines_cache,
    get_outlines_vocabulary,
)
28
29
30
31
32
33

if TYPE_CHECKING:
    import outlines_core as oc
    import outlines_core.json_schema as json_schema
else:
    oc = LazyLoader("oc", globals(), "outlines_core")
34
    json_schema = LazyLoader("json_schema", globals(), "outlines_core.json_schema")
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

# In Python 3.11+ the standalone sre_parse and sre_constants modules
# are deprecated, so we must load their replacements from the re package
if sys.version_info >= (3, 11):
    # Hack to get around pre-commit regex module rule
    # because going through re is the only way to get sre_parse
    # and sre_constants in Python 3.11+
    _re = importlib.import_module("re")
    sre_parse = _re._parser
    sre_constants = _re._constants
else:
    import sre_constants
    import sre_parse


@dataclass
class OutlinesBackend(StructuredOutputBackend):
    """Structured-output backend built on ``outlines_core``.

    A request (JSON schema, regex or choice list) is turned into a regex,
    compiled into an ``oc.Index`` (memoised in ``self.cache``) and wrapped in
    an ``OutlinesGrammar``.

    NOTE(review): ``tokenizer``, ``vocab_size`` and ``vllm_config`` are read
    from ``self``; they are presumably declared on
    ``StructuredOutputBackend`` -- confirm against the base class.
    """

    def __post_init__(self):
        # Both objects come from the shared structured-output utilities.
        self.vocabulary = get_outlines_vocabulary(self.tokenizer)
        self.cache = get_outlines_cache()

    def _compile_index(
        self, regex_string: str, vocabulary: OutlinesVocabulary
    ) -> oc.Index:
        """Return the ``oc.Index`` for ``regex_string``, compiling on a miss.

        The cache key combines the vocabulary hash with the regex so the same
        pattern compiled against different vocabularies does not collide.
        """
        key = f"{vocabulary._hash}_{regex_string}"
        if key in self.cache:
            return self.cache[key]

        compiled = oc.Index(regex_string, vocabulary.inner)
        self.cache[key] = compiled

        return compiled

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Build an ``OutlinesGrammar`` for one structured-output request.

        Args:
            request_type: Kind of specification; JSON, REGEX and CHOICE are
                handled here.
            grammar_spec: The specification text. For CHOICE it is parsed
                with ``ast.literal_eval`` and each entry is regex-escaped.

        Raises:
            ValueError: If ``request_type`` is none of the handled kinds.
        """
        if request_type == StructuredOutputOptions.JSON:
            pattern = json_schema.build_regex_from_schema(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            pattern = grammar_spec
        elif request_type == StructuredOutputOptions.CHOICE:
            escaped = [
                regex_escape(option) for option in ast.literal_eval(grammar_spec)
            ]
            pattern = f"({'|'.join(escaped)})"
        else:
            raise ValueError(
                f"Invalid request type for Outlines backend ({request_type!s})"
            )

        compiled = self._compile_index(pattern, self.vocabulary)

        # The guide's rollback depth is the number of speculative tokens
        # when speculative decoding is configured, otherwise 0.
        spec_config = self.vllm_config.speculative_config
        if spec_config is None:
            rollback_budget = 0
        else:
            rollback_budget = spec_config.num_speculative_tokens

        guide = oc.Guide(compiled, max_rollback=rollback_budget)
        return OutlinesGrammar(vocab_size=self.vocab_size, guide=guide)

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """Allocate an int32 bitmask with one row per sequence.

        Each row has ``ceil(vocab_size / 32)`` words and is filled with -1,
        i.e. every bit set. Memory is pinned when CUDA is available.
        """
        words_per_row = (self.vocab_size + 31) // 32
        return torch.full(
            (max_num_seqs, words_per_row),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        """No-op: this backend does nothing on destroy."""
        pass


@dataclass
class OutlinesGrammar(StructuredOutputGrammar):
    """Per-request grammar state wrapping an ``outlines_core`` guide."""

    vocab_size: int
    guide: oc.Guide = field(hash=False)
    # Count of tokens fed to the guide; adjusted by accept/rollback/reset.
    num_processed_tokens: int = field(
        default_factory=lambda: 0, repr=False, hash=False, init=False
    )

    # outlines_core signals done on DFA accept; vLLM expects done after EOS.
    # We delay the finished flag by one step so EOS can still be emitted.
    _prev_finished: bool = field(default=False, init=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance.
        """
        if not self.guide.accepts_tokens(tokens):
            return False

        # The whole sequence was pre-checked above, so each individual
        # advance is expected to succeed.
        for token in tokens:
            self.guide.advance(token)
            self.num_processed_tokens += 1
        return True

    def rollback(self, num_tokens: int) -> None:
        """Roll the guide back by ``num_tokens`` and lower the counter.

        NOTE(review): ``_prev_finished`` is left untouched here -- confirm
        that is intended when rolling back out of a finished state.
        """
        self.guide.rollback_state(num_tokens)
        self.num_processed_tokens -= num_tokens

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Return the longest prefix of ``tokens`` the guide accepts.

        Only ``accepts_tokens`` is called; ``advance`` is not.
        """
        valid_prefix: list[int] = []
        for token in tokens:
            if not self.guide.accepts_tokens([*valid_prefix, token]):
                break
            valid_prefix.append(token)
        return valid_prefix

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Have the guide write its mask into row ``idx`` of ``bitmask``."""
        row = bitmask[idx]
        # The guide writes through a raw pointer, so it is given the
        # address, element count and element size of the row.
        self.guide.write_mask_into(row.data_ptr(), row.numel(), row.element_size())

    def is_terminated(self) -> bool:
        """Return the guide's finished flag as seen on the previous call.

        Each call records the current ``guide.is_finished()`` value and
        returns the one recorded by the call before it.
        """
        now_finished = self.guide.is_finished()
        was_finished, self._prev_finished = self._prev_finished, now_finished
        return was_finished

    def reset(self):
        """Clear the token counter and delayed flag, then reset the guide."""
        self.num_processed_tokens = 0
        self._prev_finished = False
        self.guide.reset()


def validate_structured_output_request_outlines(params: SamplingParams):
    """Validate the structured-output part of ``params`` for Outlines.

    The first truthy field among ``regex``, ``json``, ``choice`` and
    ``grammar`` (in that order) decides what is checked. Regexes -- given
    directly, derived from the JSON schema, or built from the choices -- are
    passed to ``validate_regex_is_buildable``. Nothing is checked when
    ``params.structured_outputs`` is None or none of the fields is set.

    Raises:
        ValueError: If a JSON string schema is not valid JSON, a non-string
            schema cannot be serialised, or a grammar specification is given
            (unsupported by this backend).
    """
    so_params = params.structured_outputs
    if so_params is None:
        return

    if so_params.regex:
        validate_regex_is_buildable(so_params.regex)
        return

    if so_params.json:
        raw_schema = so_params.json
        if isinstance(raw_schema, str):
            try:
                # make sure schema is valid json
                json.loads(raw_schema)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
            schema = raw_schema
        else:
            try:
                schema = json.dumps(raw_schema)
            except Exception as e:
                raise ValueError(
                    f"Error serializing structured outputs jsonschema: {e}"
                ) from e
        validate_regex_is_buildable(json_schema.build_regex_from_schema(schema))
        return

    if so_params.choice:
        # Escape every choice so it is matched literally in the alternation.
        alternatives = "|".join(
            regex_escape(str(choice)) for choice in so_params.choice
        )
        validate_regex_is_buildable("(" + alternatives + ")")
        return

    if so_params.grammar:
        raise ValueError(
            "Outlines structured outputs backend "
            "does not support grammar specifications"
        )
195
196
197
198
199
200
201


def _prefix_needs_context(parsed) -> bool:
    """Return True if an anchor or look-around precedes any consumer.

    ``parsed`` is either an ``sre_parse`` sub-pattern object (its ``data``
    attribute is used) or a plain sequence of ``(op, arg)`` pairs. The scan
    walks tokens left to right and stops at the first token that can consume
    a character.
    """
    consumer_ops = (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY)
    context_ops = (sre_parse.AT, sre_constants.ASSERT, sre_constants.ASSERT_NOT)

    def can_consume(node) -> bool:
        """Return True if ``node`` can consume at least one character."""
        for op, arg in getattr(node, "data", node):
            # literal, character class, or dot always consumes
            if op in consumer_ops:
                return True
            if op == sre_parse.MAX_REPEAT:
                # quantified sub-pattern: consumes if it may repeat at least
                # once and its body consumes
                _, upper, body = arg
                if upper != 0 and can_consume(body):
                    return True
            elif op == sre_parse.BRANCH:
                # alternation: one consuming branch is enough
                _, alternatives = arg
                if any(can_consume(alt) for alt in alternatives):
                    return True
            elif op == sre_parse.SUBPATTERN and can_consume(arg[3]):
                # group: arg is (group, add_flags, del_flags, subpattern)
                return True
        return False

    for op, arg in getattr(parsed, "data", parsed):
        # Direct anchors or look-around
        if op in context_ops:
            return True

        # Immediate consumer tokens end the prefix
        if op in consumer_ops:
            return False

        if op == sre_parse.BRANCH:
            # Any branch starting with context => True; otherwise a
            # consuming branch ends the prefix.
            any_consumes = False
            for alt in arg[1]:
                if _prefix_needs_context(alt):
                    return True
                any_consumes = any_consumes or can_consume(alt)
            if any_consumes:
                return False
            continue

        # Groups and greedy repeats are handled the same way: look inside.
        if op == sre_parse.SUBPATTERN:
            inner = arg[3]
        elif op == sre_parse.MAX_REPEAT:
            inner = arg[2]
        else:
            continue

        if _prefix_needs_context(inner):
            return True
        if can_consume(inner):
            return False

    return False


def _check_unsupported(parsed) -> None:
    """Check for regex features unsupported by regex-automata.

    Walks the ``sre_parse`` token tree and raises on the first backreference,
    look-around assertion or word-boundary anchor found. The walk descends
    into alternations, greedy and lazy repeats, and groups, so constructs
    nested inside ``(...)`` or ``...*?`` are detected as well.

    Args:
        parsed: An ``sre_parse`` sub-pattern object (its ``data`` attribute
            is used) or a plain sequence of ``(op, arg)`` pairs.

    Raises:
        ValueError: If an unsupported construct is present anywhere in the
            pattern.
    """
    tokens = parsed.data if hasattr(parsed, "data") else parsed
    for ttype, tval in tokens:
        # backreference
        if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
            raise ValueError("Backreferences are unsupported.")

        # look-around assertion
        elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
            raise ValueError("Look-Around assertion are unsupported.")

        # unicode word boundaries
        elif ttype == sre_parse.AT:
            if tval in (sre_constants.AT_BOUNDARY, sre_constants.AT_NON_BOUNDARY):
                raise ValueError("Unicode word boundaries are unsupported.")

        elif ttype == sre_parse.BRANCH:
            # tval is (None, branches)
            for branch in tval[1]:
                _check_unsupported(branch)

        # tval is (min, max, subpattern); lazy repeats share this shape
        elif ttype in (sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT):
            _check_unsupported(tval[2])

        # tval is (group, add_flags, del_flags, subpattern); without this
        # descent anything wrapped in parentheses would escape the check
        elif ttype == sre_parse.SUBPATTERN:
            _check_unsupported(tval[3])


def validate_regex_is_buildable(pattern: str) -> None:
    """
    Validates that the input regex is not using unsupported features
    of the `regex-automata` crate (outlines_core regex engine) and has a
    universal start state.
    definition of universal start state used can be found at:
    https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state

    Args:
        pattern: The regular expression source to validate.

    Raises:
        ValueError: If the pattern cannot be parsed, uses an unsupported
            feature, or begins with an anchor/look-around that requires
            context before any character is consumed.
    """
    try:
        parsed = sre_parse.parse(pattern)
    except sre_constants.error as e:
        raise ValueError(f"Error parsing regex: {e}") from e

    try:
        _check_unsupported(parsed)
    except ValueError as e:
        raise ValueError(
            f"Regex uses unsupported feature for structured outputs: {e}. "
            "Only basic matching constructs are supported—lookarounds, "
            "backreferences, and unicode boundaries are not."
        ) from e

    if _prefix_needs_context(parsed):
        # Adjacent literals are concatenated verbatim, so every fragment
        # must carry its own trailing separator.
        raise ValueError(
            "Regex does not have an anchored universal start state. "
            "This means that the Regex uses anchors (^) or look-arounds "
            "in a way which requires context before any token is matched. "
            "Structured outputs needs regexes that can match without needing "
            "that context. Try rewriting the pattern without using these "
            f"constructs. Pattern:\n{pattern}"
        )