Unverified Commit cef0553d authored by drbh, committed by GitHub

Outlines guided generation (#1539)

This WIP PR starts to add grammar support via outlines. Currently it supports only very simple regex grammars and does not optimize for precompiling or caching grammar FSMs.

todo:
- [X] add simple outlines guidance to `NextTokenChooser`
- [X] update protos for grammar
- [X] update generation params API
- [X] constrain simple grammar
- [ ] support parsing more complex grammars into FSMs
- [ ] support all grammar types that outlines supports
- [ ] explore optimizations to avoid recompiling grammars

guided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data-raw '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6,
        "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+"
    }
}' | jq
```
response
```json
{
  "generated_text": "david@example.com"
}
```

unguided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6
    }
}' | jq
```
response
```json
{
  "generated_text": "    email = 'david"
}
```
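
For reference, the same guided request can be issued through the Python client updated in this PR. A minimal sketch, assuming a grammar-enabled server on `localhost:3000` as in the curl examples above:
```python
# Sketch only: uses the `grammar` parameter added to the Python client in this PR.
from text_generation import Client
from text_generation.types import Grammar, GrammarType

client = Client("http://localhost:3000")
response = client.generate(
    "make an email for david: \n",
    max_new_tokens=6,
    # constrain decoding with the same email regex as the guided curl request
    grammar=Grammar(type=GrammarType.Regex, value="[\\w-]+@([\\w-]+\\.)+[\\w-]+"),
)
print(response.generated_text)  # e.g. "david@example.com"
```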
parent 4c2848b2
...@@ -8,7 +8,7 @@ use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use std::io;
use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
...@@ -45,6 +45,8 @@ pub async fn run(
repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
};
// Initialize terminal properties
...
...@@ -10,6 +10,7 @@ from text_generation.types import (
Response,
Request,
Parameters,
Grammar,
)
from text_generation.errors import parse_error
...@@ -76,6 +77,7 @@ class Client:
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Response:
"""
Given a prompt, generate the following text
...@@ -138,6 +140,7 @@ class Client:
watermark=watermark,
decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens,
grammar=grammar,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
...@@ -169,6 +172,7 @@ class Client:
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
...@@ -227,6 +231,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
grammar=grammar,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
...@@ -326,6 +331,7 @@ class AsyncClient:
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
...@@ -370,6 +376,7 @@ class AsyncClient:
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
...@@ -388,6 +395,7 @@ class AsyncClient:
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
grammar=grammar,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
...@@ -417,6 +425,7 @@ class AsyncClient:
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
...@@ -475,6 +484,7 @@ class AsyncClient:
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
grammar=grammar,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
...
from enum import Enum
from pydantic import BaseModel, validator
from typing import Optional, List, Union
from text_generation.errors import ValidationError
# enum for grammar type
class GrammarType(str, Enum):
Json = "json"
Regex = "regex"
# Grammar type and value
class Grammar(BaseModel):
# Grammar type
type: GrammarType
# Grammar value
value: Union[str, dict]
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
...@@ -41,6 +55,8 @@ class Parameters(BaseModel):
decoder_input_details: bool = False
# Return the N most likely tokens at each step
top_n_tokens: Optional[int] = None
# grammar to use for generation
grammar: Optional[Grammar] = None
@validator("best_of")
def valid_best_of(cls, field_value, values):
...@@ -109,6 +125,14 @@ class Parameters(BaseModel):
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
@validator("grammar")
def valid_grammar(cls, v):
if v is not None:
if v.type == GrammarType.Regex and not v.value:
raise ValidationError("`value` cannot be empty for `regex` grammar")
if v.type == GrammarType.Json and not v.value:
raise ValidationError("`value` cannot be empty for `json` grammar")
return v
class Request(BaseModel):
# Prompt
...@@ -157,7 +181,7 @@ class Token(BaseModel):
# Token text
text: str
# Logprob
logprob: Optional[float] = None
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
...
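
To show how the new client types compose, here is an illustrative sketch (not part of the diff): `Grammar` carries a type plus a regex string or a JSON schema (string or dict), and the `valid_grammar` validator on `Parameters` rejects empty values.
```python
# Illustrative usage of the new Grammar/Parameters types from this PR.
from text_generation.types import Grammar, GrammarType, Parameters

regex_grammar = Grammar(type=GrammarType.Regex, value="[\\w-]+@([\\w-]+\\.)+[\\w-]+")
json_grammar = Grammar(
    type=GrammarType.Json,
    value={"type": "object", "properties": {"name": {"type": "string"}}},
)
params = Parameters(max_new_tokens=6, grammar=regex_grammar)

# An empty value is rejected by the validator, e.g.
# Parameters(grammar=Grammar(type=GrammarType.Regex, value=""))
# raises ValidationError("`value` cannot be empty for `regex` grammar").
```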
...@@ -378,6 +378,14 @@ Options:
[env: TOKENIZER_CONFIG_PATH=]
```
## DISABLE_GRAMMAR_SUPPORT
```shell
--disable-grammar-support
Disable outlines grammar constrained generation. This is a feature that allows you to generate text that follows a specific grammar
[env: DISABLE_GRAMMAR_SUPPORT=]
```
## ENV
```shell
...
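
As a usage note (not part of the diff), the new launcher flag can be set on the command line or through its environment variable; the model id below is a placeholder:
```bash
# hypothetical invocations
text-generation-launcher --model-id <model-id> --disable-grammar-support
DISABLE_GRAMMAR_SUPPORT=true text-generation-launcher --model-id <model-id>
```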
...@@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient
from text_generation.types import (
Response,
Details,
InputToken,
Token,
BestOfSequence,
Grammar,
)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
...@@ -224,6 +231,7 @@ def launcher(event_loop):
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
):
port = random.randint(8000, 10_000)
...@@ -247,6 +255,8 @@ def launcher(event_loop):
env = os.environ
if disable_grammar_support:
args.append("--disable-grammar-support")
if num_shard is not None:
args.extend(["--num-shard", str(num_shard)])
if quantize is not None:
...@@ -287,12 +297,15 @@ def launcher(event_loop):
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
):
port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"]
if disable_grammar_support:
args.append("--disable-grammar-support")
if num_shard is not None:
args.extend(["--num-shard", str(num_shard)])
if quantize is not None:
...@@ -370,11 +383,22 @@ def launcher(event_loop):
@pytest.fixture(scope="module")
def generate_load():
async def generate_load_inner(
client: AsyncClient,
prompt: str,
max_new_tokens: int,
n: int,
seed: Optional[int] = None,
grammar: Optional[Grammar] = None,
stop_sequences: Optional[List[str]] = None,
) -> List[Response]:
futures = [
client.generate(
prompt,
max_new_tokens=max_new_tokens,
decoder_input_details=True,
seed=seed,
grammar=grammar,
stop_sequences=stop_sequences,
)
for _ in range(n)
]
...
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -13.90625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.328125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0566406,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.5253906,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.7578125,
"special": false,
"text": "I"
},
{
"id": 4966,
"logprob": -1.9033203,
"special": false,
"text": " hope"
},
{
"id": 445,
"logprob": -0.5019531,
"special": false,
"text": " this"
},
{
"id": 6911,
"logprob": -0.21264648,
"special": false,
"text": " helps"
},
{
"id": 29991,
"logprob": -0.5991211,
"special": false,
"text": "!"
},
{
"id": 2803,
"logprob": -0.37475586,
"special": false,
"text": " Let"
},
{
"id": 592,
"logprob": -0.018463135,
"special": false,
"text": " me"
},
{
"id": 1073,
"logprob": -0.0008597374,
"special": false,
"text": " know"
}
],
"top_tokens": null
},
"generated_text": "\n\nI hope this helps! Let me know"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 30,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 5235,
"logprob": -10.0625,
"text": "info"
},
{
"id": 29901,
"logprob": -3.2324219,
"text": ":"
},
{
"id": 13260,
"logprob": -10.625,
"text": "dav"
},
{
"id": 333,
"logprob": -0.08276367,
"text": "id"
},
{
"id": 8753,
"logprob": -7.5273438,
"text": "hol"
},
{
"id": 17559,
"logprob": -3.8476562,
"text": "tz"
},
{
"id": 763,
"logprob": -10.140625,
"text": "like"
},
{
"id": 10697,
"logprob": -10.1953125,
"text": "trees"
},
{
"id": 322,
"logprob": -2.5742188,
"text": "and"
},
{
"id": 756,
"logprob": -7.4882812,
"text": "has"
},
{
"id": 1023,
"logprob": -5.0507812,
"text": "two"
},
{
"id": 274,
"logprob": -5.3164062,
"text": "c"
},
{
"id": 1446,
"logprob": -0.6694336,
"text": "ats"
},
{
"id": 29889,
"logprob": -0.9995117,
"text": "."
},
{
"id": 29871,
"logprob": -4.2421875,
"text": ""
}
],
"seed": null,
"tokens": [
{
"id": 6377,
"logprob": -0.14916992,
"special": false,
"text": "{\""
},
{
"id": 29888,
"logprob": -0.13598633,
"special": false,
"text": "f"
},
{
"id": 12935,
"logprob": -0.017669678,
"special": false,
"text": "irs"
},
{
"id": 29873,
"logprob": -0.00085639954,
"special": false,
"text": "t"
},
{
"id": 1170,
"logprob": -0.0054016113,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.13549805,
"special": false,
"text": "\":\""
},
{
"id": 19504,
"logprob": -0.8852539,
"special": false,
"text": "David"
},
{
"id": 3284,
"logprob": -0.16394043,
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.020492554,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.0013818741,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0067749023,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.11578369,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.004131317,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.0033359528,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.20471191,
"special": false,
"text": "\",\""
},
{
"id": 29882,
"logprob": -0.0069274902,
"special": false,
"text": "h"
},
{
"id": 20838,
"logprob": -0.19580078,
"special": false,
"text": "obb"
},
{
"id": 29891,
"logprob": -2.2649765e-06,
"special": false,
"text": "y"
},
{
"id": 4710,
"logprob": -0.32080078,
"special": false,
"text": "\":\""
},
{
"id": 29911,
"logprob": -2.1035156,
"special": false,
"text": "T"
},
{
"id": 11003,
"logprob": -0.020767212,
"special": false,
"text": "rees"
},
{
"id": 3284,
"logprob": -0.6010742,
"special": false,
"text": "\",\""
},
{
"id": 29876,
"logprob": -0.57666016,
"special": false,
"text": "n"
},
{
"id": 398,
"logprob": -0.0061073303,
"special": false,
"text": "um"
},
{
"id": 29907,
"logprob": -0.45703125,
"special": false,
"text": "C"
},
{
"id": 1446,
"logprob": -0.0002872944,
"special": false,
"text": "ats"
},
{
"id": 1115,
"logprob": -0.0021018982,
"special": false,
"text": "\":"
},
{
"id": 29906,
"logprob": -0.08996582,
"special": false,
"text": "2"
},
{
"id": 29913,
"logprob": -0.021697998,
"special": false,
"text": "}"
},
{
"id": 2,
"logprob": 0.0,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.03125,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04244995,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4863281,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32714844,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7685547,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.2376709,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.01008606,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64160156,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.5,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46557617,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5341797,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022907257,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
}
]
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 806,
"logprob": -11.890625,
"text": "Wh"
},
{
"id": 1446,
"logprob": -3.6699219,
"text": "ats"
},
{
"id": 2921,
"logprob": -7.8203125,
"text": "Go"
},
{
"id": 468,
"logprob": -8.0703125,
"text": "og"
},
{
"id": 793,
"logprob": -2.1875,
"text": "les"
},
{
"id": 16332,
"logprob": -9.7109375,
"text": "DNS"
}
],
"seed": null,
"tokens": [
{
"id": 29946,
"logprob": -1.4765625,
"special": false,
"text": "4"
},
{
"id": 29906,
"logprob": -0.9199219,
"special": false,
"text": "2"
},
{
"id": 29889,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -1.1367188,
"special": false,
"text": "1"
},
{
"id": 29889,
"logprob": -1.4648438,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -0.40722656,
"special": false,
"text": "1"
},
{
"id": 29889,
"logprob": -0.17419434,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -0.20251465,
"special": false,
"text": "1"
},
{
"id": 29900,
"logprob": -1.5527344,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": -1.3710938,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": "42.1.1.101"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7685547,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.33666992,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.009979248,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.53759766,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.0008878708,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.002275467,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
}
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_grammar(flash_llama_grammar_handle):
await flash_llama_grammar_handle.health(300)
return flash_llama_grammar_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Whats Googles DNS",
max_new_tokens=10,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
},
)
assert response.details.generated_tokens == 10
assert response.generated_text == "42.1.1.101"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"info: david holtz like trees and has two cats. ",
max_new_tokens=100,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Json, # "json"
"value": json.dumps(
{
"type": "object",
"$id": "https://example.com/person.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Person",
"properties": {
"firstName": {
"type": "string",
"description": "The person'''s first name.",
},
"lastName": {
"type": "string",
"description": "The person'''s last name.",
},
"hobby": {
"description": "The person'''s hobby.",
"type": "string",
},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
},
"required": ["firstName", "lastName", "hobby", "numCats"],
}
),
},
)
assert response.details.generated_tokens == 30
assert (
response.generated_text
== '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_load(
flash_llama_grammar, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_grammar,
"name: david. email: ",
max_new_tokens=10,
n=4,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
assert len(responses) == 4
expected = "123456@gmail.com"
for response in responses:
assert response.generated_text == expected
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_single_load_instance(
flash_llama_grammar, generate_load, response_snapshot
):
response = await flash_llama_grammar.generate(
"name: david. email: ",
max_new_tokens=10,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
# assert response.details.generated_tokens == 30
assert response.generated_text == "123456@gmail.com"
assert response == response_snapshot
...@@ -382,6 +382,11 @@ struct Args {
#[clap(long, env)]
tokenizer_config_path: Option<String>,
/// Disable outlines grammar constrained generation.
/// This is a feature that allows you to generate text that follows a specific grammar.
#[clap(long, env)]
disable_grammar_support: bool,
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
...@@ -1051,6 +1056,11 @@ fn spawn_webserver(
args.model_id,
];
// Grammar support
if args.disable_grammar_support {
router_args.push("--disable-grammar-support".to_string());
}
// Tokenizer config path
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
router_args.push("--tokenizer-config-path".to_string());
...
...@@ -51,6 +51,12 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
...@@ -70,6 +76,10 @@ message NextTokenChooserParameters {
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
}
message StoppingCriteriaParameters {
...
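
As a sketch of how these proto fields get populated, here is an assumed Rust snippet mirroring the defaults used elsewhere in this diff (prost-generated types, illustrative values only):
```rust
// Illustrative only: a regex-constrained request over the proto types.
use text_generation_client::{GrammarType, NextTokenChooserParameters};

let parameters = NextTokenChooserParameters {
    grammar: "[\\w-]+@([\\w-]+\\.)+[\\w-]+".to_string(), // empty string means "no grammar"
    grammar_type: GrammarType::Regex as i32,
    ..Default::default()
};
```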
...@@ -128,6 +128,8 @@ impl Client {
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
...
...@@ -9,8 +9,8 @@ pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
...
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
...@@ -45,6 +46,8 @@ impl Health {
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
...
...@@ -45,6 +45,43 @@ impl HubTokenizerConfig {
}
}
mod json_object_or_string_to_string {
use serde::{Deserialize, Deserializer};
use serde_json::Value;
// A custom deserializer that treats both strings and objects as strings.
// This provides flexibility with input formats for the 'grammar' field.
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(s),
// Safely handle serialization and return an error if it fails
Value::Object(o) => {
serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
}
_ => Err(serde::de::Error::custom(
"expected string or object for grammar",
)),
}
}
}
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
#[serde(
rename = "json",
deserialize_with = "json_object_or_string_to_string::deserialize"
)]
Json(String),
#[serde(rename = "regex")]
Regex(String),
}
mod token_serde {
use super::*;
use serde::de;
...@@ -201,6 +238,8 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
#[serde(default)]
pub grammar: Option<GrammarType>,
}
fn default_max_new_tokens() -> Option<u32> {
...@@ -226,6 +265,7 @@ fn default_parameters() -> GenerateParameters {
decoder_input_details: false,
seed: None,
top_n_tokens: None,
grammar: None,
}
}
...
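
To illustrate the custom deserializer above (example payload, not part of the diff): because an object value is re-serialized to its string form, a `grammar` field like the one below deserializes to `GrammarType::Json`, and the same schema passed as an escaped JSON string is accepted as well. The `regex` variant takes a plain string.
```json
{ "grammar": { "type": "json", "value": { "type": "object" } } }
```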
...@@ -75,6 +75,8 @@ struct Args {
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
}
#[tokio::main]
...@@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
} = args;
// Launch Tokio runtime
...@@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge,
tokenizer_config,
messages_api_enabled,
disable_grammar_support,
)
.await?;
Ok(())
...
...@@ -343,7 +343,9 @@ enum QueueCommand {
#[cfg(test)]
mod tests {
use super::*;
use text_generation_client::{
GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
};
use tracing::info_span;
fn default_entry() -> (
...@@ -354,7 +356,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: String::new(),
input_length: 0,
truncate: 0,
decoder_input_details: false,
...@@ -368,6 +370,8 @@ mod tests {
repetition_penalty: 0.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,
...
...@@ -614,6 +614,7 @@ async fn chat_completions(
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: None,
},
};
...@@ -779,6 +780,7 @@ pub async fn run(
ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
messages_api_enabled: bool,
grammar_support: bool,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
...@@ -840,6 +842,7 @@ pub async fn run(
max_top_n_tokens,
max_input_length,
max_total_tokens,
grammar_support,
);
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone());
...