"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4fcf11d8006db4f0e603d131d1c462c00bc6fead"
Unverified Commit de6cb15f authored by drbh, committed by GitHub

fix: improve tool type, bump pydantic and outlines (#1650)

This PR resolves a couple of issues:

- [X] adjusts the tool response to align with OpenAI's tools response type
- [X] bumps pydantic to `2.6.4` in all apps (resolves a dependency issue when running tests)
- [X] bumps the `outlines` version and fixes an import for its new name
parent 4f09c80c
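For context before the diffs: the aligned response now returns `tool_calls` as a list of call objects rather than a single object, matching OpenAI's chat-completions schema. A minimal pydantic sketch of the new shape (the class names and fields below are illustrative stand-ins for the sketch, not the client's actual exports):

```python
from typing import Any, Dict, List, Optional
from pydantic import BaseModel


# Illustrative stand-ins for the client types in text_generation/types.py;
# the names and exact fields here are assumptions for this sketch.
class FunctionCall(BaseModel):
    description: Optional[str] = None
    name: str
    parameters: Dict[str, Any]


class ToolCall(BaseModel):
    id: int
    type: str
    function: FunctionCall


class AssistantMessage(BaseModel):
    role: str
    content: Optional[str] = None
    # Previously a single object; now a list, as in OpenAI's API.
    tool_calls: Optional[List[ToolCall]] = None


msg = AssistantMessage(
    role="assistant",
    tool_calls=[
        ToolCall(
            id=0,
            type="function",
            function=FunctionCall(
                name="tools",
                parameters={"format": "celsius", "location": "New York, NY"},
            ),
        )
    ],
)
assert msg.tool_calls[0].function.name == "tools"
```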
@@ -377,6 +377,12 @@ version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
+[[package]]
+name = "cfg_aliases"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e"
+
 [[package]]
 name = "clap"
 version = "4.5.1"
@@ -545,7 +551,7 @@ version = "3.4.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b467862cc8610ca6fc9a1532d7777cee0804e678ab45410897b9396495994a0b"
 dependencies = [
- "nix",
+ "nix 0.27.1",
  "windows-sys 0.52.0",
 ]
@@ -1613,6 +1619,18 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "nix"
+version = "0.28.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4"
+dependencies = [
+ "bitflags 2.4.2",
+ "cfg-if",
+ "cfg_aliases",
+ "libc",
+]
+
 [[package]]
 name = "nohash-hasher"
 version = "0.2.0"
@@ -3000,7 +3018,7 @@ dependencies = [
  "clap",
  "ctrlc",
  "float_eq",
- "nix",
+ "nix 0.28.0",
  "reqwest",
  "serde",
  "serde_json",
...
@@ -12,7 +12,7 @@ repository = "https://github.com/huggingface/text-generation-inference"
 
 [tool.poetry.dependencies]
 python = "^3.7"
-pydantic = "> 1.10, < 3"
+pydantic = "> 2, < 3"
 aiohttp = "^3.8"
 huggingface-hub = ">= 0.12, < 1.0"
...
 from enum import Enum
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, field_validator
 from typing import Optional, List, Union, Any
 
 from text_generation.errors import ValidationError
@@ -32,7 +32,7 @@ class Message(BaseModel):
     # Role of the message sender
     role: str
     # Content of the message
-    content: Optional[str]
+    content: Optional[str] = None
     # Optional name of the message sender
     name: Optional[str] = None
     # Tool calls associated with the chat completion
@@ -56,7 +56,7 @@ class ChatCompletionComplete(BaseModel):
     # Reason for completion
     finish_reason: str
     # Usage details of the chat completion
-    usage: Any
+    usage: Optional[Any] = None
 
 
 class Function(BaseModel):
@@ -73,7 +73,7 @@ class ChoiceDeltaToolCall(BaseModel):
 class ChoiceDelta(BaseModel):
     role: str
-    content: Optional[str]
+    content: Optional[str] = None
     tool_calls: Optional[ChoiceDeltaToolCall]
@@ -176,74 +176,74 @@ class Parameters(BaseModel):
     # grammar to use for generation
     grammar: Optional[Grammar] = None
 
-    @validator("best_of")
+    @field_validator("best_of")
     def valid_best_of(cls, field_value, values):
         if field_value is not None:
             if field_value <= 0:
                 raise ValidationError("`best_of` must be strictly positive")
-            if field_value > 1 and values["seed"] is not None:
+            if field_value > 1 and values.data["seed"] is not None:
                 raise ValidationError("`seed` must not be set when `best_of` is > 1")
             sampling = (
-                values["do_sample"]
-                | (values["temperature"] is not None)
-                | (values["top_k"] is not None)
-                | (values["top_p"] is not None)
-                | (values["typical_p"] is not None)
+                values.data["do_sample"]
+                | (values.data["temperature"] is not None)
+                | (values.data["top_k"] is not None)
+                | (values.data["top_p"] is not None)
+                | (values.data["typical_p"] is not None)
             )
             if field_value > 1 and not sampling:
                 raise ValidationError("you must use sampling when `best_of` is > 1")
 
         return field_value
 
-    @validator("repetition_penalty")
+    @field_validator("repetition_penalty")
     def valid_repetition_penalty(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v
 
-    @validator("seed")
+    @field_validator("seed")
     def valid_seed(cls, v):
         if v is not None and v < 0:
             raise ValidationError("`seed` must be positive")
         return v
 
-    @validator("temperature")
+    @field_validator("temperature")
     def valid_temp(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`temperature` must be strictly positive")
         return v
 
-    @validator("top_k")
+    @field_validator("top_k")
     def valid_top_k(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`top_k` must be strictly positive")
         return v
 
-    @validator("top_p")
+    @field_validator("top_p")
     def valid_top_p(cls, v):
         if v is not None and (v <= 0 or v >= 1.0):
             raise ValidationError("`top_p` must be > 0.0 and < 1.0")
         return v
 
-    @validator("truncate")
+    @field_validator("truncate")
     def valid_truncate(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`truncate` must be strictly positive")
         return v
 
-    @validator("typical_p")
+    @field_validator("typical_p")
     def valid_typical_p(cls, v):
         if v is not None and (v <= 0 or v >= 1.0):
             raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v
 
-    @validator("top_n_tokens")
+    @field_validator("top_n_tokens")
     def valid_top_n_tokens(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`top_n_tokens` must be strictly positive")
         return v
 
-    @validator("grammar")
+    @field_validator("grammar")
     def valid_grammar(cls, v):
         if v is not None:
             if v.type == GrammarType.Regex and not v.value:
@@ -261,15 +261,15 @@ class Request(BaseModel):
     # Whether to stream output tokens
     stream: bool = False
 
-    @validator("inputs")
+    @field_validator("inputs")
     def valid_input(cls, v):
         if not v:
             raise ValidationError("`inputs` cannot be empty")
         return v
 
-    @validator("stream")
+    @field_validator("stream")
     def valid_best_of_stream(cls, field_value, values):
-        parameters = values["parameters"]
+        parameters = values.data["parameters"]
         if (
             parameters is not None
             and parameters.best_of is not None
...
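The hunk above is mostly the mechanical pydantic v1 to v2 migration: `@validator` becomes `@field_validator`, and cross-field lookups go through the validation-info object's `.data` mapping instead of a plain `values` dict. A minimal standalone sketch of the pattern (a toy class for illustration, not the client's actual model):

```python
from typing import Optional
from pydantic import BaseModel, ValidationInfo, field_validator


class Params(BaseModel):
    # `seed` is declared first, so it has already been processed (and is
    # available in info.data) when the `best_of` validator runs.
    seed: Optional[int] = None
    best_of: Optional[int] = None

    @field_validator("best_of")
    @classmethod
    def valid_best_of(cls, v: Optional[int], info: ValidationInfo):
        if v is not None and v > 1 and info.data.get("seed") is not None:
            raise ValueError("`seed` must not be set when `best_of` is > 1")
        return v


Params(best_of=2)             # passes
# Params(seed=42, best_of=2)  # raises a pydantic ValidationError
```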
@@ -25,6 +25,7 @@ from text_generation.types import (
     Grammar,
     ChatComplete,
     ChatCompletionChunk,
+    ChatCompletionComplete,
 )
 
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -42,11 +43,16 @@ class ResponseComparator(JSONSnapshotExtension):
         exclude=None,
         matcher=None,
     ):
-        if isinstance(data, Response):
-            data = data.dict()
+        if (
+            isinstance(data, Response)
+            or isinstance(data, ChatComplete)
+            or isinstance(data, ChatCompletionChunk)
+            or isinstance(data, ChatCompletionComplete)
+        ):
+            data = data.model_dump()
 
         if isinstance(data, List):
-            data = [d.dict() for d in data]
+            data = [d.model_dump() for d in data]
 
         data = self._filter(
             data=data, depth=0, path=(), exclude=exclude, matcher=matcher
...
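The conftest change is the same migration applied to serialization: pydantic v2 renames `.dict()` to `.model_dump()`. For example, under pydantic 2.x:

```python
from pydantic import BaseModel


class Item(BaseModel):
    name: str


item = Item(name="demo")
# pydantic v1 spelled this `item.dict()`; v2 deprecates that in favor of:
assert item.model_dump() == {"name": "demo"}
```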
@@ -13,7 +13,7 @@
       "usage": null
     }
   ],
-  "created": 1708957015,
+  "created": 1710795556,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
...
@@ -8,24 +8,26 @@
         "content": null,
         "name": null,
         "role": "assistant",
-        "tool_calls": {
-          "function": {
-            "description": null,
-            "name": "tools",
-            "parameters": {
-              "format": "celsius",
-              "location": "New York, NY",
-              "num_days": 14
-            }
-          },
-          "id": 0,
-          "type": "function"
-        }
+        "tool_calls": [
+          {
+            "function": {
+              "description": null,
+              "name": "tools",
+              "parameters": {
+                "format": "celsius",
+                "location": "New York, NY",
+                "num_days": 14
+              }
+            },
+            "id": 0,
+            "type": "function"
+          }
+        ]
       },
       "usage": null
     }
   ],
-  "created": 1709079417,
+  "created": 1710795556,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
...
@@ -8,24 +8,26 @@
         "content": null,
         "name": null,
         "role": "assistant",
-        "tool_calls": {
-          "function": {
-            "description": null,
-            "name": "tools",
-            "parameters": {
-              "format": "celsius",
-              "location": "New York, NY",
-              "num_days": 14
-            }
-          },
-          "id": 0,
-          "type": "function"
-        }
+        "tool_calls": [
+          {
+            "function": {
+              "description": null,
+              "name": "tools",
+              "parameters": {
+                "format": "celsius",
+                "location": "New York, NY",
+                "num_days": 14
+              }
+            },
+            "id": 0,
+            "type": "function"
+          }
+        ]
       },
       "usage": null
     }
   ],
-  "created": 1709079492,
+  "created": 1710795557,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
...
@@ -8,23 +8,25 @@
         "content": null,
         "name": null,
         "role": "assistant",
-        "tool_calls": {
-          "function": {
-            "description": null,
-            "name": "tools",
-            "parameters": {
-              "format": "celsius",
-              "location": "New York, NY"
-            }
-          },
-          "id": 0,
-          "type": "function"
-        }
+        "tool_calls": [
+          {
+            "function": {
+              "description": null,
+              "name": "tools",
+              "parameters": {
+                "format": "celsius",
+                "location": "New York, NY"
+              }
+            },
+            "id": 0,
+            "type": "function"
+          }
+        ]
       },
       "usage": null
     }
   ],
-  "created": 1709079493,
+  "created": 1710795557,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
...
@@ -10,16 +10,16 @@
             "name": null
           },
           "id": "",
-          "index": 20,
+          "index": 0,
           "type": "function"
         }
       },
       "finish_reason": "eos_token",
-      "index": 20,
+      "index": 0,
       "logprobs": null
     }
   ],
-  "created": 1709087088,
+  "created": 1710795499,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
...
@@ -119,19 +119,21 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
         ],
     )
     assert response.choices[0].message.content == None
-    assert response.choices[0].message.tool_calls == {
-        "function": {
-            "description": None,
-            "name": "tools",
-            "parameters": {
-                "format": "celsius",
-                "location": "New York, NY",
-                "num_days": 14,
-            },
-        },
-        "id": 0,
-        "type": "function",
-    }
+    assert response.choices[0].message.tool_calls == [
+        {
+            "function": {
+                "description": None,
+                "name": "tools",
+                "parameters": {
+                    "format": "celsius",
+                    "location": "New York, NY",
+                    "num_days": 14,
+                },
+            },
+            "id": 0,
+            "type": "function",
+        }
+    ]
 
     assert response == response_snapshot
@@ -159,19 +161,21 @@ async def test_flash_llama_grammar_tools_auto(
         ],
     )
     assert response.choices[0].message.content == None
-    assert response.choices[0].message.tool_calls == {
-        "function": {
-            "description": None,
-            "name": "tools",
-            "parameters": {
-                "format": "celsius",
-                "location": "New York, NY",
-                "num_days": 14,
-            },
-        },
-        "id": 0,
-        "type": "function",
-    }
+    assert response.choices[0].message.tool_calls == [
+        {
+            "function": {
+                "description": None,
+                "name": "tools",
+                "parameters": {
+                    "format": "celsius",
+                    "location": "New York, NY",
+                    "num_days": 14,
+                },
+            },
+            "id": 0,
+            "type": "function",
+        }
+    ]
 
     assert response == response_snapshot
@@ -199,15 +203,17 @@ async def test_flash_llama_grammar_tools_choice(
         ],
     )
     assert response.choices[0].message.content == None
-    assert response.choices[0].message.tool_calls == {
-        "id": 0,
-        "type": "function",
-        "function": {
-            "description": None,
-            "name": "tools",
-            "parameters": {"format": "celsius", "location": "New York, NY"},
-        },
-    }
+    assert response.choices[0].message.tool_calls == [
+        {
+            "id": 0,
+            "type": "function",
+            "function": {
+                "description": None,
+                "name": "tools",
+                "parameters": {"format": "celsius", "location": "New York, NY"},
+            },
+        }
+    ]
 
     assert response == response_snapshot
...
@@ -5,6 +5,7 @@ description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]
 
 [tool.poetry.dependencies]
+pydantic = "> 2, < 3"
 python = ">=3.9,<3.13"
 syrupy = "4.0.1"
 text-generation = "^0.6.0"
...
 aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13"
 aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
+annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
 attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
@@ -17,14 +18,15 @@ iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
 multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13"
 packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
 pluggy==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
-pydantic==1.10.12 ; python_version >= "3.9" and python_version < "3.13"
+pydantic-core==2.16.3 ; python_version >= "3.9" and python_version < "3.13"
+pydantic==2.6.4 ; python_version >= "3.9" and python_version < "3.13"
 pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
 pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13"
 pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
 syrupy==4.0.1 ; python_version >= "3.9" and python_version < "3.13"
-text-generation==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
...
@@ -433,7 +433,7 @@ impl ChatCompletion {
         created: u64,
         details: Details,
         return_logprobs: bool,
-        tool_calls: Option<ToolCall>,
+        tool_calls: Option<Vec<ToolCall>>,
     ) -> Self {
         Self {
             id: String::new(),
@@ -781,7 +781,7 @@ pub(crate) struct Message {
     #[schema(example = "\"David\"")]
     pub name: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub tool_calls: Option<ToolCall>,
+    pub tool_calls: Option<Vec<ToolCall>>,
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
...
@@ -942,7 +942,7 @@ async fn chat_completions(
                 )
             })?;
 
-            let tool_call = Some(ToolCall {
+            let tool_calls = vec![ToolCall {
                 id: 0,
                 r#type: "function".to_string(),
                 function: FunctionDefinition {
@@ -963,8 +963,8 @@ async fn chat_completions(
                         |f| Ok(f.clone()),
                     )?,
                 },
-            });
-            (tool_call, None)
+            }];
+            (Some(tool_calls), None)
         } else {
             (None, Some(generation.generated_text))
         };
...
@@ -34,7 +34,7 @@ peft = { version = "^0.9.0", optional = true }
 torch = { version = "^2.1.1", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
-outlines= { version = "^0.0.27", optional = true }
+outlines= { version = "^0.0.36", optional = true }
 
 [tool.poetry.extras]
 torch = ["torch"]
...
@@ -6,7 +6,7 @@ from typing import Dict, Union
 from text_generation_server.pb.generate_pb2 import GrammarType
 
 from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.json_schema import build_regex_from_object
+from outlines.fsm.json_schema import build_regex_from_schema
 from functools import lru_cache
 from typing import List, Optional, DefaultDict
 import time
@@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
     def _cached_compile_fsm(grammar_type, schema, tokenizer):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
-            schema = build_regex_from_object(schema)
+            schema = build_regex_from_schema(schema)
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
             pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
...
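The `outlines` bump (0.0.27 to 0.0.36) renames `build_regex_from_object` to `build_regex_from_schema`; the behavior (turning a JSON schema into a regular expression for the FSM to compile) is unchanged. A usage sketch, assuming outlines ~0.0.36 is installed:

```python
# Sketch assuming outlines ~0.0.36; only the function's name changed.
from outlines.fsm.json_schema import build_regex_from_schema

schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
regex = build_regex_from_schema(schema)
print(regex[:80])  # regex string that RegexFSM compiles into a token-level FSM
```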