import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Optional, Tuple, Union

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              ChatCompletionRequest)
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
                                                          RegexLogitsProcessor,
                                                          CFGLogitsProcessor)


class GuidedDecodingMode(Enum):
    """Kind of constraint a guided-decoding request places on generation.

    ``CHOICE`` guides are expressed as a regex, so they are handled by the
    same logits processor as ``REGEX``.
    """
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"
    GRAMMAR = "grammar"


# Adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is the start rule: upstream uses `start: value`, while
# this grammar uses `start: object | array`, so scalar values are rejected as
# the root of the JSON document. Allowing a scalar root seems to make llama
# keep generating without ever stopping.
JSON_GRAMMAR = r"""
?start: object | array

?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER      -> number
| "true"             -> true
| "false"            -> false
| "null"             -> null

array  : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair   : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""

# Executor used for generating logits processor FSMs; created lazily on the
# first request that carries a guide.
global_thread_pool = None


async def get_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.

    The processor (and its FSM) is built on a small module-level thread pool
    via ``run_in_executor`` so the construction does not run on the event
    loop thread. Processors are cached by (guide, tokenizer, mode); on every
    call we return a shallow copy of the cached instance, so the underlying
    FSM is shared while the copy's internal state is reset with
    ``init_state()``.

    Returns:
        A fresh logits processor, or ``None`` when the request carries no
        guided-decoding parameters.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide:
        return None

    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()

    result = await loop.run_in_executor(global_thread_pool,
                                        _get_cached_logits_processor, guide,
                                        tokenizer, mode)

    # Shallow copy: reuse the cached FSM without sharing per-request state.
    logits_processor = copy(result)
    # reset logits processor's internal state
    logits_processor.init_state()
    return logits_processor


def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[Optional[str], Optional[GuidedDecodingMode]]:
    """Extract the guide and its decoding mode from ``request``.

    The guided-decoding fields are checked in this order, and the first one
    that is set wins: ``guided_json``, ``guided_regex``, ``guided_choice``,
    ``guided_grammar``, then ``response_format`` of type ``"json_object"``
    (which maps to the built-in ``JSON_GRAMMAR``). Dict and pydantic JSON
    guides are serialized to strings so the guide can serve as a hashable
    cache key.

    Returns:
        ``(guide, mode)``, or ``(None, None)`` when no guided decoding is
        requested.
    """
    if request.guided_json:
        json = request.guided_json
        if isinstance(json, dict):
            # turn dict into hashable string
            json = json_dumps(json, sort_keys=True)
        elif isinstance(json, BaseModel):
            # The guide must be a JSON schema (and hashable for the cache),
            # so serialize the model's schema rather than its Python call
            # signature, which is not a schema at all.
            json = json_dumps(json.model_json_schema(), sort_keys=True)
        return json, GuidedDecodingMode.JSON
    elif request.guided_regex:
        return request.guided_regex, GuidedDecodingMode.REGEX
    elif request.guided_choice:
        # choice just uses regex: an alternation of the escaped choices
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE
    elif request.guided_grammar:
        return request.guided_grammar, GuidedDecodingMode.GRAMMAR
    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
    else:
        return None, None


@lru_cache(maxsize=32)
def _get_cached_logits_processor(
    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
    """Build the logits processor that enforces ``guide`` under ``mode``.

    Memoized with ``lru_cache`` (up to 32 entries) keyed on
    ``(guide, tokenizer, mode)``, so all three arguments must be hashable.
    The returned object is shared between cache hits; the caller in this
    module shallow-copies it and resets the copy's state before use.

    Raises:
        ValueError: if ``mode`` is not a handled GuidedDecodingMode.
    """
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer)
    if mode in (GuidedDecodingMode.REGEX, GuidedDecodingMode.CHOICE):
        # CHOICE guides are already expressed as a regex alternation.
        return RegexLogitsProcessor(guide, tokenizer)
    if mode == GuidedDecodingMode.GRAMMAR:
        return CFGLogitsProcessor(guide, tokenizer)
    raise ValueError(f"Unknown guided decoding mode {mode}")