Unverified Commit de6cb15f authored by drbh, committed by GitHub

fix: improve tool type, bump pydantic and outlines (#1650)

This PR resolves a couple of issues:

- [X] adjusts the tool response to align with OpenAI's tools response type (see the sketch below)
- [X] bumps pydantic to `2.6.4` in all apps (resolves a dependency issue when running tests)
- [X] bumps the `outlines` version and fixes an import for the renamed function
parent 4f09c80c
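For reference, the first item is a shape change: `tool_calls` on an assistant message becomes a list of tool-call objects instead of a single object, matching what OpenAI's chat completions API returns. A minimal sketch of the new shape as a plain Python dict (values taken from the updated snapshots and tests below):

```python
# Sketch of the OpenAI-aligned message shape: `tool_calls` is now a
# list of tool-call objects rather than a single object.
message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [  # previously a single dict, now a list
        {
            "id": 0,
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "parameters": {"format": "celsius", "location": "New York, NY"},
            },
        }
    ],
}
```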
@@ -377,6 +377,12 @@ version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
+[[package]]
+name = "cfg_aliases"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e"
+
 [[package]]
 name = "clap"
 version = "4.5.1"
@@ -545,7 +551,7 @@ version = "3.4.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b467862cc8610ca6fc9a1532d7777cee0804e678ab45410897b9396495994a0b"
 dependencies = [
- "nix",
+ "nix 0.27.1",
  "windows-sys 0.52.0",
 ]
@@ -1613,6 +1619,18 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "nix"
+version = "0.28.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4"
+dependencies = [
+ "bitflags 2.4.2",
+ "cfg-if",
+ "cfg_aliases",
+ "libc",
+]
+
 [[package]]
 name = "nohash-hasher"
 version = "0.2.0"
@@ -3000,7 +3018,7 @@ dependencies = [
  "clap",
  "ctrlc",
  "float_eq",
- "nix",
+ "nix 0.28.0",
  "reqwest",
  "serde",
  "serde_json",
......
This diff is collapsed.
@@ -12,7 +12,7 @@ repository = "https://github.com/huggingface/text-generation-inference"
 [tool.poetry.dependencies]
 python = "^3.7"
-pydantic = "> 1.10, < 3"
+pydantic = "> 2, < 3"
 aiohttp = "^3.8"
 huggingface-hub = ">= 0.12, < 1.0"
......
 from enum import Enum
-from pydantic import BaseModel, validator
+from pydantic import BaseModel, field_validator
 from typing import Optional, List, Union, Any
 from text_generation.errors import ValidationError
@@ -32,7 +32,7 @@ class Message(BaseModel):
     # Role of the message sender
     role: str
     # Content of the message
-    content: Optional[str]
+    content: Optional[str] = None
     # Optional name of the message sender
     name: Optional[str] = None
     # Tool calls associated with the chat completion
@@ -56,7 +56,7 @@ class ChatCompletionComplete(BaseModel):
     # Reason for completion
     finish_reason: str
     # Usage details of the chat completion
-    usage: Any
+    usage: Optional[Any] = None


 class Function(BaseModel):
@@ -73,7 +73,7 @@ class ChoiceDeltaToolCall(BaseModel):
 class ChoiceDelta(BaseModel):
     role: str
-    content: Optional[str]
+    content: Optional[str] = None
     tool_calls: Optional[ChoiceDeltaToolCall]
@@ -176,74 +176,74 @@ class Parameters(BaseModel):
     # grammar to use for generation
     grammar: Optional[Grammar] = None

-    @validator("best_of")
+    @field_validator("best_of")
     def valid_best_of(cls, field_value, values):
         if field_value is not None:
             if field_value <= 0:
                 raise ValidationError("`best_of` must be strictly positive")
-            if field_value > 1 and values["seed"] is not None:
+            if field_value > 1 and values.data["seed"] is not None:
                 raise ValidationError("`seed` must not be set when `best_of` is > 1")
             sampling = (
-                values["do_sample"]
-                | (values["temperature"] is not None)
-                | (values["top_k"] is not None)
-                | (values["top_p"] is not None)
-                | (values["typical_p"] is not None)
+                values.data["do_sample"]
+                | (values.data["temperature"] is not None)
+                | (values.data["top_k"] is not None)
+                | (values.data["top_p"] is not None)
+                | (values.data["typical_p"] is not None)
             )
             if field_value > 1 and not sampling:
                 raise ValidationError("you must use sampling when `best_of` is > 1")

         return field_value

-    @validator("repetition_penalty")
+    @field_validator("repetition_penalty")
     def valid_repetition_penalty(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v

-    @validator("seed")
+    @field_validator("seed")
     def valid_seed(cls, v):
         if v is not None and v < 0:
             raise ValidationError("`seed` must be positive")
         return v

-    @validator("temperature")
+    @field_validator("temperature")
     def valid_temp(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`temperature` must be strictly positive")
         return v

-    @validator("top_k")
+    @field_validator("top_k")
     def valid_top_k(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`top_k` must be strictly positive")
         return v

-    @validator("top_p")
+    @field_validator("top_p")
     def valid_top_p(cls, v):
         if v is not None and (v <= 0 or v >= 1.0):
             raise ValidationError("`top_p` must be > 0.0 and < 1.0")
         return v

-    @validator("truncate")
+    @field_validator("truncate")
     def valid_truncate(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`truncate` must be strictly positive")
         return v

-    @validator("typical_p")
+    @field_validator("typical_p")
     def valid_typical_p(cls, v):
         if v is not None and (v <= 0 or v >= 1.0):
             raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v

-    @validator("top_n_tokens")
+    @field_validator("top_n_tokens")
     def valid_top_n_tokens(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`top_n_tokens` must be strictly positive")
         return v

-    @validator("grammar")
+    @field_validator("grammar")
     def valid_grammar(cls, v):
         if v is not None:
             if v.type == GrammarType.Regex and not v.value:
@@ -261,15 +261,15 @@ class Request(BaseModel):
     # Whether to stream output tokens
     stream: bool = False

-    @validator("inputs")
+    @field_validator("inputs")
     def valid_input(cls, v):
         if not v:
             raise ValidationError("`inputs` cannot be empty")
         return v

-    @validator("stream")
+    @field_validator("stream")
     def valid_best_of_stream(cls, field_value, values):
-        parameters = values["parameters"]
+        parameters = values.data["parameters"]
         if (
             parameters is not None
             and parameters.best_of is not None
......
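The validator changes above are the standard pydantic v1 to v2 migration: `@validator` becomes `@field_validator`, and previously validated fields are read from the `.data` attribute of the second argument (a `ValidationInfo` in v2) rather than from a plain `values` dict. A minimal sketch of the pattern, separate from this repo's types (the model and field names are illustrative):

```python
from typing import Optional

from pydantic import BaseModel, ValidationInfo, field_validator


class Params(BaseModel):
    seed: Optional[int] = None
    best_of: Optional[int] = None

    @field_validator("best_of")
    def valid_best_of(cls, v: Optional[int], info: ValidationInfo):
        # In v2, fields declared *before* this one ("seed") are available
        # on `info.data`; in v1 they arrived as a plain `values` dict.
        if v is not None and v > 1 and info.data["seed"] is not None:
            raise ValueError("`seed` must not be set when `best_of` is > 1")
        return v
```

Note that `info.data` only contains fields declared before the one being validated, which is why `seed` comes first in the model.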
@@ -25,6 +25,7 @@ from text_generation.types import (
     Grammar,
     ChatComplete,
     ChatCompletionChunk,
+    ChatCompletionComplete,
 )

 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -42,11 +43,16 @@ class ResponseComparator(JSONSnapshotExtension):
         exclude=None,
         matcher=None,
     ):
-        if isinstance(data, Response):
-            data = data.dict()
+        if (
+            isinstance(data, Response)
+            or isinstance(data, ChatComplete)
+            or isinstance(data, ChatCompletionChunk)
+            or isinstance(data, ChatCompletionComplete)
+        ):
+            data = data.model_dump()

         if isinstance(data, List):
-            data = [d.dict() for d in data]
+            data = [d.model_dump() for d in data]

         data = self._filter(
             data=data, depth=0, path=(), exclude=exclude, matcher=matcher
......
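The `.dict()` to `.model_dump()` rename above is likewise part of the pydantic v2 API; the old name still works in v2 but emits a deprecation warning. A throwaway illustration:

```python
from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int


p = Point(x=1, y=2)
assert p.model_dump() == {"x": 1, "y": 2}  # v2 spelling of v1's p.dict()
```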
@@ -13,7 +13,7 @@
       "usage": null
     }
   ],
-  "created": 1708957015,
+  "created": 1710795556,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
......
@@ -8,7 +8,8 @@
         "content": null,
         "name": null,
         "role": "assistant",
-        "tool_calls": {
+        "tool_calls": [
+          {
           "function": {
             "description": null,
             "name": "tools",
@@ -21,11 +22,12 @@
           "id": 0,
           "type": "function"
         }
+        ]
       },
       "usage": null
     }
   ],
-  "created": 1709079417,
+  "created": 1710795556,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
......
@@ -8,7 +8,8 @@
         "content": null,
         "name": null,
         "role": "assistant",
-        "tool_calls": {
+        "tool_calls": [
+          {
           "function": {
             "description": null,
             "name": "tools",
@@ -21,11 +22,12 @@
           "id": 0,
           "type": "function"
         }
+        ]
       },
       "usage": null
     }
   ],
-  "created": 1709079492,
+  "created": 1710795557,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
......
@@ -8,7 +8,8 @@
         "content": null,
         "name": null,
         "role": "assistant",
-        "tool_calls": {
+        "tool_calls": [
+          {
           "function": {
             "description": null,
             "name": "tools",
@@ -20,11 +21,12 @@
           "id": 0,
           "type": "function"
         }
+        ]
       },
       "usage": null
     }
   ],
-  "created": 1709079493,
+  "created": 1710795557,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
......
@@ -10,16 +10,16 @@
             "name": null
           },
           "id": "",
-          "index": 20,
+          "index": 0,
           "type": "function"
         }
       },
       "finish_reason": "eos_token",
-      "index": 20,
+      "index": 0,
       "logprobs": null
     }
   ],
-  "created": 1709087088,
+  "created": 1710795499,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "text_completion",
......
@@ -119,7 +119,8 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
         ],
     )
     assert response.choices[0].message.content == None
-    assert response.choices[0].message.tool_calls == {
+    assert response.choices[0].message.tool_calls == [
+        {
             "function": {
                 "description": None,
                 "name": "tools",
@@ -132,6 +133,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
             "id": 0,
             "type": "function",
         }
+    ]
     assert response == response_snapshot
@@ -159,7 +161,8 @@ async def test_flash_llama_grammar_tools_auto(
         ],
     )
     assert response.choices[0].message.content == None
-    assert response.choices[0].message.tool_calls == {
+    assert response.choices[0].message.tool_calls == [
+        {
             "function": {
                 "description": None,
                 "name": "tools",
@@ -172,6 +175,7 @@ async def test_flash_llama_grammar_tools_auto(
             "id": 0,
             "type": "function",
         }
+    ]
     assert response == response_snapshot
@@ -199,7 +203,8 @@ async def test_flash_llama_grammar_tools_choice(
         ],
     )
     assert response.choices[0].message.content == None
-    assert response.choices[0].message.tool_calls == {
+    assert response.choices[0].message.tool_calls == [
+        {
             "id": 0,
             "type": "function",
             "function": {
@@ -208,6 +213,7 @@ async def test_flash_llama_grammar_tools_choice(
                 "parameters": {"format": "celsius", "location": "New York, NY"},
             },
         }
+    ]
     assert response == response_snapshot
......
This diff is collapsed.
@@ -5,6 +5,7 @@ description = "Text Generation Inference integration tests"
 authors = ["Nicolas Patry <nicolas@huggingface.co>"]

 [tool.poetry.dependencies]
+pydantic = "> 2, < 3"
 python = ">=3.9,<3.13"
 syrupy = "4.0.1"
 text-generation = "^0.6.0"
......
 aiohttp==3.8.5 ; python_version >= "3.9" and python_version < "3.13"
 aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
+annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.13"
 attrs==23.1.0 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
@@ -17,14 +18,15 @@ iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.13"
 multidict==6.0.4 ; python_version >= "3.9" and python_version < "3.13"
 packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
 pluggy==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
-pydantic==1.10.12 ; python_version >= "3.9" and python_version < "3.13"
+pydantic-core==2.16.3 ; python_version >= "3.9" and python_version < "3.13"
+pydantic==2.6.4 ; python_version >= "3.9" and python_version < "3.13"
 pytest-asyncio==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
 pytest==7.4.0 ; python_version >= "3.9" and python_version < "3.13"
 pywin32==306 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
 syrupy==4.0.1 ; python_version >= "3.9" and python_version < "3.13"
-text-generation==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+text-generation==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 tomli==2.0.1 ; python_version >= "3.9" and python_version < "3.11"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
......
@@ -433,7 +433,7 @@ impl ChatCompletion {
         created: u64,
         details: Details,
         return_logprobs: bool,
-        tool_calls: Option<ToolCall>,
+        tool_calls: Option<Vec<ToolCall>>,
     ) -> Self {
         Self {
             id: String::new(),
@@ -781,7 +781,7 @@ pub(crate) struct Message {
     #[schema(example = "\"David\"")]
     pub name: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub tool_calls: Option<ToolCall>,
+    pub tool_calls: Option<Vec<ToolCall>>,
 }

 #[derive(Clone, Debug, Deserialize, ToSchema)]
......
@@ -942,7 +942,7 @@ async fn chat_completions(
             )
         })?;

-        let tool_call = Some(ToolCall {
+        let tool_calls = vec![ToolCall {
             id: 0,
             r#type: "function".to_string(),
             function: FunctionDefinition {
@@ -963,8 +963,8 @@ async fn chat_completions(
                     |f| Ok(f.clone()),
                 )?,
             },
-        });
-        (tool_call, None)
+        }];
+        (Some(tool_calls), None)
     } else {
         (None, Some(generation.generated_text))
     };
......
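Since the router now serializes `tool_calls` as a JSON array, clients index into the list instead of treating the field itself as the call object. A self-contained sketch of the access pattern, mirroring the updated integration tests (the payload below is illustrative, with the shape taken from the snapshots above):

```python
import json

# Hypothetical serialized chat completion after this change.
raw = """
{
  "choices": [
    {
      "finish_reason": "eos_token",
      "index": 0,
      "message": {
        "role": "assistant",
        "content": null,
        "tool_calls": [
          {
            "id": 0,
            "type": "function",
            "function": {
              "description": null,
              "name": "tools",
              "parameters": {"format": "celsius", "location": "New York, NY"}
            }
          }
        ]
      }
    }
  ]
}
"""
response = json.loads(raw)

# `tool_calls` is a list now; take the first entry rather than treating
# the field itself as the call object.
tool_call = response["choices"][0]["message"]["tool_calls"][0]
assert tool_call["type"] == "function"
print(tool_call["function"]["name"], tool_call["function"]["parameters"])
```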
@@ -34,7 +34,7 @@ peft = { version = "^0.9.0", optional = true }
 torch = { version = "^2.1.1", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
-outlines= { version = "^0.0.27", optional = true }
+outlines= { version = "^0.0.36", optional = true }

 [tool.poetry.extras]
 torch = ["torch"]
......
@@ -6,7 +6,7 @@ from typing import Dict, Union
 from text_generation_server.pb.generate_pb2 import GrammarType

 from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.json_schema import build_regex_from_object
+from outlines.fsm.json_schema import build_regex_from_schema
 from functools import lru_cache
 from typing import List, Optional, DefaultDict
 import time
@@ -512,7 +512,7 @@ class GrammarLogitProcessor(LogitsProcessor):
     def _cached_compile_fsm(grammar_type, schema, tokenizer):
         start_time = time.time()
         if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
-            schema = build_regex_from_object(schema)
+            schema = build_regex_from_schema(schema)
         elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
             pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
......
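The renamed outlines helper keeps the same role: it compiles a JSON schema, passed as a string, into a regular expression, which `RegexFSM` then turns into a token-level state machine as in `GrammarLogitProcessor` above. A minimal sketch, assuming outlines 0.0.36 is installed (the schema is illustrative):

```python
import json

from outlines.fsm.json_schema import build_regex_from_schema

# An illustrative JSON schema; the helper expects it as a string.
schema = json.dumps(
    {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    }
)

# Regex matching exactly the JSON documents valid under the schema.
regex = build_regex_from_schema(schema)
print(regex[:80])
```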