Unverified Commit cef0553d authored by drbh, committed by GitHub

Outlines guided generation (#1539)

This WIP PR starts to add grammar support via outlines. Currently it
supports very simple regex grammars and does not optimize for
precompiling or caching grammar FSMs.

todo:
- [X] add simple outlines guidance to `NextTokenChooser`
- [X] update protos for grammar
- [X] update generation params API
- [X] constrain simple grammar
- [ ] support parsing more complex grammar into fsm
- [ ] support all grammar types that outlines supports
- [ ] explore optimizations to avoid recompiling grammars

guided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data-raw '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6,
        "grammar": "[\\w-]+@([\\w-]+\\.)+[\\w-]+"
    }
}' | jq
```
response
```json
{
  "generated_text": "david@example.com"
}
```

unguided request
```bash
curl -s 'http://localhost:3000/generate' \
--header 'Content-Type: application/json' \
--data '{
    "inputs": "make an email for david: \n",
    "parameters": {
        "max_new_tokens": 6
    }
}' | jq
```
response
```json
{
  "generated_text": "    email = 'david"
}
```
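
For reference, the same guided request through the Python client changes in this PR might look like the sketch below (assumes a server from this branch listening on localhost:3000, as in the curl examples):
```python
# Sketch: guided generation via the updated text_generation client.
from text_generation import Client
from text_generation.types import Grammar, GrammarType

client = Client("http://localhost:3000")
response = client.generate(
    "make an email for david: \n",
    max_new_tokens=6,
    # same email regex as the curl example above
    grammar=Grammar(type=GrammarType.Regex, value=r"[\w-]+@([\w-]+\.)+[\w-]+"),
)
print(response.generated_text)  # e.g. david@example.com
```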
parent 4c2848b2
......@@ -8,7 +8,7 @@ use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use std::io;
-use text_generation_client::{NextTokenChooserParameters, ShardedClient};
+use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
......@@ -45,6 +45,8 @@ pub async fn run(
repetition_penalty: repetition_penalty.unwrap_or(1.0),
frequency_penalty: frequency_penalty.unwrap_or(0.0),
watermark,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
};
// Initialize terminal properties
......
......@@ -10,6 +10,7 @@ from text_generation.types import (
Response,
Request,
Parameters,
Grammar,
)
from text_generation.errors import parse_error
......@@ -76,6 +77,7 @@ class Client:
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Response:
"""
Given a prompt, generate the following text
......@@ -138,6 +140,7 @@ class Client:
watermark=watermark,
decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens,
grammar=grammar,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
......@@ -169,6 +172,7 @@ class Client:
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
......@@ -227,6 +231,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
grammar=grammar,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
......@@ -326,6 +331,7 @@ class AsyncClient:
watermark: bool = False,
decoder_input_details: bool = False,
top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> Response:
"""
Given a prompt, generate the following text asynchronously
......@@ -370,6 +376,7 @@ class AsyncClient:
Returns:
Response: generated response
"""
# Validate parameters
parameters = Parameters(
best_of=best_of,
......@@ -388,6 +395,7 @@ class AsyncClient:
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
grammar=grammar,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
......@@ -417,6 +425,7 @@ class AsyncClient:
typical_p: Optional[float] = None,
watermark: bool = False,
top_n_tokens: Optional[int] = None,
grammar: Optional[Grammar] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
......@@ -475,6 +484,7 @@ class AsyncClient:
typical_p=typical_p,
watermark=watermark,
top_n_tokens=top_n_tokens,
grammar=grammar,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)
......
from enum import Enum
from pydantic import BaseModel, validator
-from typing import Optional, List
+from typing import Optional, List, Union
from text_generation.errors import ValidationError
# enum for grammar type
class GrammarType(str, Enum):
Json = "json"
Regex = "regex"
# Grammar type and value
class Grammar(BaseModel):
# Grammar type
type: GrammarType
# Grammar value
value: Union[str, dict]
class Parameters(BaseModel):
# Activate logits sampling
do_sample: bool = False
......@@ -41,6 +55,8 @@ class Parameters(BaseModel):
decoder_input_details: bool = False
# Return the N most likely tokens at each step
top_n_tokens: Optional[int] = None
# grammar to use for generation
grammar: Optional[Grammar] = None
@validator("best_of")
def valid_best_of(cls, field_value, values):
......@@ -109,6 +125,14 @@ class Parameters(BaseModel):
raise ValidationError("`top_n_tokens` must be strictly positive")
return v
@validator("grammar")
def valid_grammar(cls, v):
if v is not None:
if v.type == GrammarType.Regex and not v.value:
raise ValidationError("`value` cannot be empty for `regex` grammar")
if v.type == GrammarType.Json and not v.value:
raise ValidationError("`value` cannot be empty for `json` grammar")
return v
class Request(BaseModel):
# Prompt
......@@ -157,7 +181,7 @@ class Token(BaseModel):
# Token text
text: str
# Logprob
-logprob: float
+logprob: Optional[float] = None
# Is the token a special token
# Can be used to ignore tokens when concatenating
special: bool
......
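A minimal sketch of exercising the new pydantic types above (names exactly as in the diff; the empty-value check comes from the `valid_grammar` validator):
```python
# Sketch: building the new Grammar/Parameters types and triggering the
# empty-value validation defined above.
from text_generation.types import Grammar, GrammarType, Parameters

params = Parameters(
    max_new_tokens=6,
    grammar=Grammar(type=GrammarType.Regex, value=r"[\w-]+@([\w-]+\.)+[\w-]+"),
)
print(params.grammar.type)  # GrammarType.Regex

try:
    Parameters(grammar=Grammar(type=GrammarType.Regex, value=""))
except Exception as err:  # pydantic surfaces the ValidationError raised above
    print(err)
```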
......@@ -378,6 +378,14 @@ Options:
[env: TOKENIZER_CONFIG_PATH=]
```
## DISABLE_GRAMMAR_SUPPORT
```shell
--disable-grammar-support
Disable outlines grammar constrained generation. This is a feature that allows you to generate text that follows a specific grammar
[env: DISABLE_GRAMMAR_SUPPORT=]
```
## ENV
```shell
......
......@@ -16,7 +16,14 @@ from syrupy.extensions.json import JSONSnapshotExtension
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from text_generation import AsyncClient
-from text_generation.types import Response, Details, InputToken, Token, BestOfSequence
+from text_generation.types import (
+Response,
+Details,
+InputToken,
+Token,
+BestOfSequence,
+Grammar,
+)
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
......@@ -224,6 +231,7 @@ def launcher(event_loop):
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
):
port = random.randint(8000, 10_000)
......@@ -247,6 +255,8 @@ def launcher(event_loop):
env = os.environ
if disable_grammar_support:
args.append("--disable-grammar-support")
if num_shard is not None:
args.extend(["--num-shard", str(num_shard)])
if quantize is not None:
......@@ -287,12 +297,15 @@ def launcher(event_loop):
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
):
port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"]
if disable_grammar_support:
args.append("--disable-grammar-support")
if num_shard is not None:
args.extend(["--num-shard", str(num_shard)])
if quantize is not None:
......@@ -370,11 +383,22 @@ def launcher(event_loop):
@pytest.fixture(scope="module")
def generate_load():
async def generate_load_inner(
-client: AsyncClient, prompt: str, max_new_tokens: int, n: int
+client: AsyncClient,
+prompt: str,
+max_new_tokens: int,
+n: int,
+seed: Optional[int] = None,
+grammar: Optional[Grammar] = None,
+stop_sequences: Optional[List[str]] = None,
) -> List[Response]:
futures = [
client.generate(
-prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
+prompt,
+max_new_tokens=max_new_tokens,
+decoder_input_details=True,
+seed=seed,
+grammar=grammar,
+stop_sequences=stop_sequences,
)
for _ in range(n)
]
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -13.90625,
"text": "Test"
},
{
"id": 2009,
"logprob": -12.328125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.0566406,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.5253906,
"special": false,
"text": "\n"
},
{
"id": 29902,
"logprob": -2.7578125,
"special": false,
"text": "I"
},
{
"id": 4966,
"logprob": -1.9033203,
"special": false,
"text": " hope"
},
{
"id": 445,
"logprob": -0.5019531,
"special": false,
"text": " this"
},
{
"id": 6911,
"logprob": -0.21264648,
"special": false,
"text": " helps"
},
{
"id": 29991,
"logprob": -0.5991211,
"special": false,
"text": "!"
},
{
"id": 2803,
"logprob": -0.37475586,
"special": false,
"text": " Let"
},
{
"id": 592,
"logprob": -0.018463135,
"special": false,
"text": " me"
},
{
"id": 1073,
"logprob": -0.0008597374,
"special": false,
"text": " know"
}
],
"top_tokens": null
},
"generated_text": "\n\nI hope this helps! Let me know"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 30,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 5235,
"logprob": -10.0625,
"text": "info"
},
{
"id": 29901,
"logprob": -3.2324219,
"text": ":"
},
{
"id": 13260,
"logprob": -10.625,
"text": "dav"
},
{
"id": 333,
"logprob": -0.08276367,
"text": "id"
},
{
"id": 8753,
"logprob": -7.5273438,
"text": "hol"
},
{
"id": 17559,
"logprob": -3.8476562,
"text": "tz"
},
{
"id": 763,
"logprob": -10.140625,
"text": "like"
},
{
"id": 10697,
"logprob": -10.1953125,
"text": "trees"
},
{
"id": 322,
"logprob": -2.5742188,
"text": "and"
},
{
"id": 756,
"logprob": -7.4882812,
"text": "has"
},
{
"id": 1023,
"logprob": -5.0507812,
"text": "two"
},
{
"id": 274,
"logprob": -5.3164062,
"text": "c"
},
{
"id": 1446,
"logprob": -0.6694336,
"text": "ats"
},
{
"id": 29889,
"logprob": -0.9995117,
"text": "."
},
{
"id": 29871,
"logprob": -4.2421875,
"text": ""
}
],
"seed": null,
"tokens": [
{
"id": 6377,
"logprob": -0.14916992,
"special": false,
"text": "{\""
},
{
"id": 29888,
"logprob": -0.13598633,
"special": false,
"text": "f"
},
{
"id": 12935,
"logprob": -0.017669678,
"special": false,
"text": "irs"
},
{
"id": 29873,
"logprob": -0.00085639954,
"special": false,
"text": "t"
},
{
"id": 1170,
"logprob": -0.0054016113,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.13549805,
"special": false,
"text": "\":\""
},
{
"id": 19504,
"logprob": -0.8852539,
"special": false,
"text": "David"
},
{
"id": 3284,
"logprob": -0.16394043,
"special": false,
"text": "\",\""
},
{
"id": 4230,
"logprob": -0.020492554,
"special": false,
"text": "last"
},
{
"id": 1170,
"logprob": -0.0013818741,
"special": false,
"text": "Name"
},
{
"id": 4710,
"logprob": -0.0067749023,
"special": false,
"text": "\":\""
},
{
"id": 29950,
"logprob": -0.11578369,
"special": false,
"text": "H"
},
{
"id": 14339,
"logprob": -0.004131317,
"special": false,
"text": "olt"
},
{
"id": 29920,
"logprob": -0.0033359528,
"special": false,
"text": "z"
},
{
"id": 3284,
"logprob": -0.20471191,
"special": false,
"text": "\",\""
},
{
"id": 29882,
"logprob": -0.0069274902,
"special": false,
"text": "h"
},
{
"id": 20838,
"logprob": -0.19580078,
"special": false,
"text": "obb"
},
{
"id": 29891,
"logprob": -2.2649765e-06,
"special": false,
"text": "y"
},
{
"id": 4710,
"logprob": -0.32080078,
"special": false,
"text": "\":\""
},
{
"id": 29911,
"logprob": -2.1035156,
"special": false,
"text": "T"
},
{
"id": 11003,
"logprob": -0.020767212,
"special": false,
"text": "rees"
},
{
"id": 3284,
"logprob": -0.6010742,
"special": false,
"text": "\",\""
},
{
"id": 29876,
"logprob": -0.57666016,
"special": false,
"text": "n"
},
{
"id": 398,
"logprob": -0.0061073303,
"special": false,
"text": "um"
},
{
"id": 29907,
"logprob": -0.45703125,
"special": false,
"text": "C"
},
{
"id": 1446,
"logprob": -0.0002872944,
"special": false,
"text": "ats"
},
{
"id": 1115,
"logprob": -0.0021018982,
"special": false,
"text": "\":"
},
{
"id": 29906,
"logprob": -0.08996582,
"special": false,
"text": "2"
},
{
"id": 29913,
"logprob": -0.021697998,
"special": false,
"text": "}"
},
{
"id": 2,
"logprob": 0.0,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.03125,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04244995,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4863281,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32714844,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7685547,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.2376709,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.01008606,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64160156,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.5,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46557617,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5341797,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022907257,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 1024,
"logprob": -10.578125,
"text": "name"
},
{
"id": 29901,
"logprob": -3.0332031,
"text": ":"
},
{
"id": 13260,
"logprob": -9.171875,
"text": "dav"
},
{
"id": 333,
"logprob": -0.04257202,
"text": "id"
},
{
"id": 29889,
"logprob": -2.4785156,
"text": "."
},
{
"id": 4876,
"logprob": -10.7890625,
"text": "email"
},
{
"id": 29901,
"logprob": -0.32495117,
"text": ":"
},
{
"id": 259,
"logprob": -9.4921875,
"text": " "
}
],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7709961,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.23840332,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.00995636,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.5361328,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.00088739395,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.0022735596,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
}
]
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 806,
"logprob": -11.890625,
"text": "Wh"
},
{
"id": 1446,
"logprob": -3.6699219,
"text": "ats"
},
{
"id": 2921,
"logprob": -7.8203125,
"text": "Go"
},
{
"id": 468,
"logprob": -8.0703125,
"text": "og"
},
{
"id": 793,
"logprob": -2.1875,
"text": "les"
},
{
"id": 16332,
"logprob": -9.7109375,
"text": "DNS"
}
],
"seed": null,
"tokens": [
{
"id": 29946,
"logprob": -1.4765625,
"special": false,
"text": "4"
},
{
"id": 29906,
"logprob": -0.9199219,
"special": false,
"text": "2"
},
{
"id": 29889,
"logprob": 0.0,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -1.1367188,
"special": false,
"text": "1"
},
{
"id": 29889,
"logprob": -1.4648438,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -0.40722656,
"special": false,
"text": "1"
},
{
"id": 29889,
"logprob": -0.17419434,
"special": false,
"text": "."
},
{
"id": 29896,
"logprob": -0.20251465,
"special": false,
"text": "1"
},
{
"id": 29900,
"logprob": -1.5527344,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": -1.3710938,
"special": false,
"text": "1"
}
],
"top_tokens": null
},
"generated_text": "42.1.1.101"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 29896,
"logprob": -0.7685547,
"special": false,
"text": "1"
},
{
"id": 29906,
"logprob": -0.33666992,
"special": false,
"text": "2"
},
{
"id": 29941,
"logprob": -0.009979248,
"special": false,
"text": "3"
},
{
"id": 29946,
"logprob": -0.64208984,
"special": false,
"text": "4"
},
{
"id": 29945,
"logprob": -0.4970703,
"special": false,
"text": "5"
},
{
"id": 29953,
"logprob": -0.46533203,
"special": false,
"text": "6"
},
{
"id": 29992,
"logprob": -0.5336914,
"special": false,
"text": "@"
},
{
"id": 21980,
"logprob": -0.53759766,
"special": false,
"text": "gmail"
},
{
"id": 29889,
"logprob": -0.0008878708,
"special": false,
"text": "."
},
{
"id": 510,
"logprob": -0.002275467,
"special": false,
"text": "com"
}
],
"top_tokens": null
},
"generated_text": "123456@gmail.com"
}
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_grammar(flash_llama_grammar_handle):
await flash_llama_grammar_handle.health(300)
return flash_llama_grammar_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"Whats Googles DNS",
max_new_tokens=10,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
},
)
assert response.details.generated_tokens == 10
assert response.generated_text == "42.1.1.101"
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
response = await flash_llama_grammar.generate(
"info: david holtz like trees and has two cats. ",
max_new_tokens=100,
decoder_input_details=True,
seed=0,
grammar={
"type": GrammarType.Json, # "json"
"value": json.dumps(
{
"type": "object",
"$id": "https://example.com/person.schema.json",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Person",
"properties": {
"firstName": {
"type": "string",
"description": "The person'''s first name.",
},
"lastName": {
"type": "string",
"description": "The person'''s last name.",
},
"hobby": {
"description": "The person'''s hobby.",
"type": "string",
},
"numCats": {
"description": "The number of cats the person has.",
"type": "integer",
"minimum": 0,
},
},
"required": ["firstName", "lastName", "hobby", "numCats"],
}
),
},
)
assert response.details.generated_tokens == 30
assert (
response.generated_text
== '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_load(
flash_llama_grammar, generate_load, response_snapshot
):
responses = await generate_load(
flash_llama_grammar,
"name: david. email: ",
max_new_tokens=10,
n=4,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
assert len(responses) == 4
expected = "123456@gmail.com"
for response in responses:
assert response.generated_text == expected
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot
# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_single_load_instance(
flash_llama_grammar, generate_load, response_snapshot
):
response = await flash_llama_grammar.generate(
"name: david. email: ",
max_new_tokens=10,
stop_sequences=[".com"],
seed=0,
grammar={
"type": GrammarType.Regex, # "regex"
"value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+", # email regex
},
)
# assert response.details.generated_tokens == 30
assert response.generated_text == "123456@gmail.com"
assert response == response_snapshot
......@@ -382,6 +382,11 @@ struct Args {
#[clap(long, env)]
tokenizer_config_path: Option<String>,
/// Disable outlines grammar constrained generation.
/// This is a feature that allows you to generate text that follows a specific grammar.
#[clap(long, env)]
disable_grammar_support: bool,
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
......@@ -1051,6 +1056,11 @@ fn spawn_webserver(
args.model_id,
];
// Grammar support
if args.disable_grammar_support {
router_args.push("--disable-grammar-support".to_string());
}
// Tokenizer config path
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
router_args.push("--tokenizer-config-path".to_string());
......
......@@ -51,6 +51,12 @@ message ClearCacheRequest {
/// Empty response
message ClearCacheResponse {}
enum GrammarType {
GRAMMAR_TYPE_NONE = 0;
GRAMMAR_TYPE_JSON = 1;
GRAMMAR_TYPE_REGEX = 2;
}
message NextTokenChooserParameters {
/// exponential scaling output probability distribution
float temperature = 1;
......@@ -70,6 +76,10 @@ message NextTokenChooserParameters {
float frequency_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
/// grammar (applied if not empty)
string grammar = 10;
/// grammar type
GrammarType grammar_type = 11;
}
message StoppingCriteriaParameters {
......
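A hypothetical Python-side sketch of filling the new proto fields, assuming protobuf-generated bindings named `generate_pb2` (the module path is an assumption; the enum constants follow standard protobuf codegen):
```python
# Sketch: the new grammar fields ride along with the sampling parameters.
from text_generation_server.pb import generate_pb2  # assumed generated module

params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    grammar=r"[\w-]+@([\w-]+\.)+[\w-]+",           # applied only if non-empty
    grammar_type=generate_pb2.GRAMMAR_TYPE_REGEX,  # GRAMMAR_TYPE_NONE when unused
)
print(params.grammar_type)  # 2
```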
......@@ -128,6 +128,8 @@ impl Client {
repetition_penalty: 1.2,
frequency_penalty: 0.1,
watermark: true,
grammar: String::new(),
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
......
......@@ -9,8 +9,8 @@ pub use client::Client;
pub use pb::generate::v2::HealthResponse;
pub use pb::generate::v2::InfoResponse as ShardInfo;
pub use pb::generate::v2::{
-Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-Request, StoppingCriteriaParameters, Tokens,
+Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
+NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
......
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use text_generation_client::GrammarType as ProtoGrammarType;
use text_generation_client::{
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
};
......@@ -45,6 +46,8 @@ impl Health {
repetition_penalty: 1.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
......
......@@ -45,6 +45,43 @@ impl HubTokenizerConfig {
}
}
mod json_object_or_string_to_string {
use serde::{Deserialize, Deserializer};
use serde_json::Value;
// A custom deserializer that treats both strings and objects as strings.
// This provides flexibility with input formats for the 'grammar' field.
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(s),
// Safely handle serialization and return an error if it fails
Value::Object(o) => {
serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
}
_ => Err(serde::de::Error::custom(
"expected string or object for grammar",
)),
}
}
}
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", content = "value")]
pub(crate) enum GrammarType {
#[serde(
rename = "json",
deserialize_with = "json_object_or_string_to_string::deserialize"
)]
Json(String),
#[serde(rename = "regex")]
Regex(String),
}
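Given the tagged representation above plus the object-or-string deserializer, the router should accept both of the `grammar` payload shapes sketched below (endpoint assumed to be the same local server as the curl examples):
```python
# Sketch: the two JSON shapes accepted by the tagged GrammarType enum.
# For "json", the value may be an object (re-serialized to a string by
# json_object_or_string_to_string above) or an already-encoded string.
import requests

regex_grammar = {"type": "regex", "value": r"[\w-]+@([\w-]+\.)+[\w-]+"}
json_grammar = {
    "type": "json",
    "value": {"type": "object", "properties": {"email": {"type": "string"}}},
}

for grammar in (regex_grammar, json_grammar):
    r = requests.post(
        "http://localhost:3000/generate",
        json={
            "inputs": "make an email for david: \n",
            "parameters": {"max_new_tokens": 6, "grammar": grammar},
        },
    )
    print(r.json())
```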
mod token_serde {
use super::*;
use serde::de;
......@@ -201,6 +238,8 @@ pub(crate) struct GenerateParameters {
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
pub top_n_tokens: Option<u32>,
#[serde(default)]
pub grammar: Option<GrammarType>,
}
fn default_max_new_tokens() -> Option<u32> {
......@@ -226,6 +265,7 @@ fn default_parameters() -> GenerateParameters {
decoder_input_details: false,
seed: None,
top_n_tokens: None,
grammar: None,
}
}
......
......@@ -75,6 +75,8 @@ struct Args {
ngrok_edge: Option<String>,
#[clap(long, env, default_value_t = false)]
messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)]
disable_grammar_support: bool,
}
#[tokio::main]
......@@ -108,6 +110,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
messages_api_enabled,
disable_grammar_support,
} = args;
// Launch Tokio runtime
......@@ -359,6 +362,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge,
tokenizer_config,
messages_api_enabled,
disable_grammar_support,
)
.await?;
Ok(())
......
......@@ -343,7 +343,9 @@ enum QueueCommand {
#[cfg(test)]
mod tests {
use super::*;
-use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+use text_generation_client::{
+GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+};
use tracing::info_span;
fn default_entry() -> (
......@@ -354,7 +356,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: "".to_string(),
inputs: String::new(),
input_length: 0,
truncate: 0,
decoder_input_details: false,
......@@ -368,6 +370,8 @@ mod tests {
repetition_penalty: 0.0,
frequency_penalty: 0.0,
watermark: false,
grammar: String::new(),
grammar_type: ProtoGrammarType::None as i32,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,
......
......@@ -614,6 +614,7 @@ async fn chat_completions(
decoder_input_details: !stream,
seed,
top_n_tokens: None,
grammar: None,
},
};
......@@ -779,6 +780,7 @@ pub async fn run(
ngrok_edge: Option<String>,
tokenizer_config: HubTokenizerConfig,
messages_api_enabled: bool,
grammar_support: bool,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
......@@ -840,6 +842,7 @@ pub async fn run(
max_top_n_tokens,
max_input_length,
max_total_tokens,
grammar_support,
);
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone());
......