Unverified Commit 654bc5ca authored by Yihuan Bu's avatar Yihuan Bu Committed by GitHub
Browse files

Support for guided decoding for offline LLM (#6878)


Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 825b0448
...@@ -111,6 +111,7 @@ autodoc_mock_imports = [ ...@@ -111,6 +111,7 @@ autodoc_mock_imports = [
"tqdm", "tqdm",
"tensorizer", "tensorizer",
"pynvml", "pynvml",
"outlines",
] ]
for mock_target in autodoc_mock_imports: for mock_target in autodoc_mock_imports:
......
import pytest import pytest
@pytest.fixture
def sample_prompts():
return [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
@pytest.fixture
def sample_token_ids():
return [
[0],
[0, 1],
[0, 2, 1],
[0, 3, 1, 2],
]
@pytest.fixture @pytest.fixture
def sample_regex(): def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
...@@ -66,4 +86,4 @@ column: "col_1" | "col_2" ...@@ -66,4 +86,4 @@ column: "col_1" | "col_2"
table: "table_1" | "table_2" table: "table_1" | "table_2"
condition: column "=" number condition: column "=" number
number: "1" | "2" number: "1" | "2"
""") """)
\ No newline at end of file
import json
import re
import weakref
import jsonschema
import pytest
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME, max_model_len=1024)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert re.fullmatch(sample_regex, generated_text) is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
)
outputs = llm.generate(
prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_json=sample_json_schema))
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_choice=sample_guided_choice))
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=1000,
)
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_grammar=sample_sql_statements))
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_statements)
parser.parse(generated_text)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
...@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, ...@@ -10,6 +10,9 @@ from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt) parse_and_batch_prompt)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -262,6 +265,8 @@ class LLM: ...@@ -262,6 +265,8 @@ class LLM:
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
...@@ -303,6 +308,14 @@ class LLM: ...@@ -303,6 +308,14 @@ class LLM:
else: else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
"You can only use one guided decoding but multiple is "
f"specified: {guided_options_request}")
guided_options_request = GuidedDecodingRequest(
**guided_options_request)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
...@@ -311,7 +324,8 @@ class LLM: ...@@ -311,7 +324,8 @@ class LLM:
inputs=inputs, inputs=inputs,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request)
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput) return LLMEngine.validate_outputs(outputs, RequestOutput)
...@@ -508,6 +522,7 @@ class LLM: ...@@ -508,6 +522,7 @@ class LLM:
Sequence[PoolingParams]], Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
) -> None: ) -> None:
if isinstance(inputs, (str, dict)): if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
...@@ -523,6 +538,15 @@ class LLM: ...@@ -523,6 +538,15 @@ class LLM:
raise ValueError("The lengths of prompts and lora_request " raise ValueError("The lengths of prompts and lora_request "
"must be the same.") "must be the same.")
if isinstance(params, list):
params = [
self._add_guided_processor(param, guided_options)
if isinstance(param, SamplingParams) else param
for param in params
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
# Add requests to the engine. # Add requests to the engine.
for i, request_inputs in enumerate(inputs): for i, request_inputs in enumerate(inputs):
self._add_request( self._add_request(
...@@ -548,6 +572,24 @@ class LLM: ...@@ -548,6 +572,24 @@ class LLM:
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
def _add_guided_processor(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options:
if guided_options.guided_decoding_backend is None:
decoding_config = self.llm_engine.get_decoding_config()
guided_options.guided_decoding_backend = (
decoding_config.guided_decoding_backend)
guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
guided_options.guided_decoding_backend, guided_options,
self.get_tokenizer())
if guided_logits_processor:
if params.logits_processors is None:
params.logits_processors = []
params.logits_processors.append(guided_logits_processor)
return params
def _run_engine( def _run_engine(
self, *, use_tqdm: bool self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
......
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
import torch import torch
...@@ -14,6 +15,23 @@ from vllm.pooling_params import PoolingParams ...@@ -14,6 +15,23 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
try:
from sphinx.ext.autodoc.mock import _MockModule
if isinstance(torch, _MockModule):
_LONG_INFO = _MOCK_LONG_INFO
else:
_LONG_INFO = torch.iinfo(torch.long)
except ModuleNotFoundError:
_LONG_INFO = torch.iinfo(torch.long)
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
class OpenAIBaseModel(BaseModel): class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields # OpenAI API does not allow extra fields
...@@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
n: Optional[int] = 1 n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None response_format: Optional[ResponseFormat] = None
seed: Optional[int] = Field(None, seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None stream_options: Optional[StreamOptions] = None
...@@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens: Optional[int] = 16 max_tokens: Optional[int] = 16
n: int = 1 n: int = 1
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None, seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None stream_options: Optional[StreamOptions] = None
......
...@@ -3,9 +3,10 @@ from typing import Optional, Union ...@@ -3,9 +3,10 @@ from typing import Optional, Union
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( from vllm.model_executor.guided_decoding.guided_fields import (
get_lm_format_enforcer_guided_decoding_logits_processor) GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import ( from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor) get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor from vllm.sampling_params import LogitsProcessor
...@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor( ...@@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
return await get_outlines_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor(
request, tokenizer) request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer': if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor( return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer) request, tokenizer)
...@@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor( ...@@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor(
"Must be one of 'outlines, 'lm-format-enforcer'") "Must be one of 'outlines, 'lm-format-enforcer'")
def get_local_guided_decoding_logits_processor(
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
tokenizer) -> Optional[LogitsProcessor]:
# request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
def _adapt_request_for_tool_use(request: Union[CompletionRequest, def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]): ChatCompletionRequest]):
# the legacy completion API does not support tool use # the legacy completion API does not support tool use
......
from dataclasses import dataclass
from typing import Dict, List, Optional, TypedDict, Union
from pydantic import BaseModel
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str]
guided_regex: str
guided_choice: List[str]
guided_grammar: str
guided_decoding_backend: str
guided_whitespace_pattern: str
guided_json_object: bool
@dataclass
class GuidedDecodingRequest:
"""One of the fields will be used to retrieve the logit processor."""
guided_json: Optional[Union[Dict, BaseModel, str]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_grammar: Optional[str] = None
guided_decoding_backend: Optional[str] = None
guided_whitespace_pattern: Optional[str] = None
guided_json_object: Optional[bool] = None
def __post_init__(self):
"""Validate that some fields are mutually exclusive."""
guide_count = sum([
self.guided_json is not None, self.guided_regex is not None,
self.guided_choice is not None, self.guided_grammar is not None,
self.guided_json_object is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding but multiple are "
f"specified: {self.__dict__}")
...@@ -12,7 +12,10 @@ from transformers import PreTrainedTokenizerBase ...@@ -12,7 +12,10 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import ( from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor) get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor from vllm.sampling_params import LogitsProcessor
...@@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( ...@@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
return logits_processor return logits_processor
def get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest,
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if guided_options.guided_json:
schema = _normalize_json_schema_object(guided_options.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif guided_options.guided_choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in guided_options.guided_choice])
elif guided_options.guided_regex:
character_level_parser = RegexParser(guided_options.guided_regex)
elif guided_options.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
elif guided_options.guided_json_object:
# None means any json object
character_level_parser = JsonSchemaParser(None)
else:
return None
logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
if isinstance(schema, str): if isinstance(schema, str):
return json_loads(schema) return json_loads(schema)
......
...@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
...@@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor(
mode, request.guided_whitespace_pattern) mode, request.guided_whitespace_pattern)
def get_local_outlines_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
guide, mode = _get_guide_and_mode(guided_options)
if not guide or not mode:
return None
return _get_logits_processor(guide, tokenizer, mode,
guided_options.guided_whitespace_pattern)
def _get_guide_and_mode( def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest] request: Union[CompletionRequest, ChatCompletionRequest,
GuidedDecodingRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
if request.guided_json: if request.guided_json:
...@@ -102,7 +123,8 @@ def _get_guide_and_mode( ...@@ -102,7 +123,8 @@ def _get_guide_and_mode(
return choices_regex, GuidedDecodingMode.CHOICE return choices_regex, GuidedDecodingMode.CHOICE
elif request.guided_grammar: elif request.guided_grammar:
return request.guided_grammar, GuidedDecodingMode.GRAMMAR return request.guided_grammar, GuidedDecodingMode.GRAMMAR
elif (request.response_format is not None elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_object"): and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment