OpenDAS / text-generation-inference

Commit fa8a8e05 (unverified), authored Feb 21, 2024 by OlivierDehaene, committed by GitHub on Feb 21, 2024

fix(router): fix openapi and add jsonschema validation (#1578)

Parent: c9f4c1af

Showing 17 changed files with 297 additions and 107 deletions (+297, -107)
.github/workflows/tests.yaml  (+4, -0)
Cargo.lock  (+157, -0)
docs/openapi.json  (+51, -0)
integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json  (+46, -46)
integration-tests/models/test_flash_awq.py  (+0, -3)
integration-tests/models/test_flash_awq_sharded.py  (+0, -2)
integration-tests/models/test_flash_medusa.py  (+0, -3)
integration-tests/models/test_flash_mistral.py  (+0, -3)
integration-tests/models/test_flash_phi.py  (+0, -3)
integration-tests/models/test_flash_starcoder_gptq.py  (+0, -3)
integration-tests/models/test_grammar_llama.py  (+1, -6)
integration-tests/models/test_mamba.py  (+0, -3)
router/Cargo.toml  (+1, -0)
router/src/lib.rs  (+7, -30)
router/src/server.rs  (+1, -0)
router/src/validation.rs  (+27, -2)
server/text_generation_server/utils/tokens.py  (+2, -3)
.github/workflows/tests.yaml

```diff
@@ -41,6 +41,10 @@ jobs:
           components: rustfmt, clippy
       - name: Install Protoc
         uses: arduino/setup-protoc@v1
+      - name: Clean unused files
+        run: |
+          sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android
+          sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
       - name: Install sccache
         run: |
           curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache
```
Cargo.lock

```diff
@@ -24,7 +24,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff"
 dependencies = [
  "cfg-if",
+ "getrandom",
  "once_cell",
+ "serde",
  "version_check",
  "zerocopy",
 ]

@@ -265,6 +267,21 @@ version = "0.21.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"

+[[package]]
+name = "bit-set"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
+dependencies = [
+ "bit-vec",
+]
+
+[[package]]
+name = "bit-vec"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"

@@ -716,6 +733,16 @@ dependencies = [
  "cc",
 ]

+[[package]]
+name = "fancy-regex"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
+dependencies = [
+ "bit-set",
+ "regex",
+]
+
 [[package]]
 name = "fastrand"
 version = "2.0.1"

@@ -780,6 +807,16 @@ dependencies = [
  "percent-encoding",
 ]

+[[package]]
+name = "fraction"
+version = "0.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678"
+dependencies = [
+ "lazy_static",
+ "num",
+]
+
 [[package]]
 name = "futures"
 version = "0.3.30"

@@ -895,8 +932,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
 dependencies = [
  "cfg-if",
+ "js-sys",
  "libc",
  "wasi",
+ "wasm-bindgen",
 ]

 [[package]]

@@ -1181,6 +1220,15 @@ version = "2.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"

+[[package]]
+name = "iso8601"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153"
+dependencies = [
+ "nom",
+]
+
 [[package]]
 name = "itertools"
 version = "0.10.5"

@@ -1223,6 +1271,36 @@ dependencies = [
  "wasm-bindgen",
 ]

+[[package]]
+name = "jsonschema"
+version = "0.17.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978"
+dependencies = [
+ "ahash",
+ "anyhow",
+ "base64 0.21.7",
+ "bytecount",
+ "clap",
+ "fancy-regex",
+ "fraction",
+ "getrandom",
+ "iso8601",
+ "itoa",
+ "memchr",
+ "num-cmp",
+ "once_cell",
+ "parking_lot",
+ "percent-encoding",
+ "regex",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "time",
+ "url",
+ "uuid",
+]
+
 [[package]]
 name = "lazy_static"
 version = "1.4.0"

@@ -1574,12 +1652,84 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "num"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
+dependencies = [
+ "num-bigint",
+ "num-complex",
+ "num-integer",
+ "num-iter",
+ "num-rational",
+ "num-traits",
+]
+
+[[package]]
+name = "num-bigint"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-cmp"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa"
+
+[[package]]
+name = "num-complex"
+version = "0.4.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-conv"
 version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"

+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-iter"
+version = "0.1.44"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-rational"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
+dependencies = [
+ "autocfg",
+ "num-bigint",
+ "num-integer",
+ "num-traits",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.18"

@@ -2874,6 +3024,7 @@ dependencies = [
  "futures-util",
  "hf-hub",
  "init-tracing-opentelemetry",
+ "jsonschema",
  "metrics",
  "metrics-exporter-prometheus",
  "minijinja",

@@ -3530,6 +3681,12 @@ dependencies = [
  "zip",
 ]

+[[package]]
+name = "uuid"
+version = "1.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a"
+
 [[package]]
 name = "valuable"
 version = "0.1.0"
```
docs/openapi.json

```diff
@@ -1022,6 +1022,57 @@
           }
         }
       },
+      "GrammarType": {
+        "oneOf": [
+          {
+            "type": "object",
+            "required": [
+              "type",
+              "value"
+            ],
+            "properties": {
+              "type": {
+                "type": "string",
+                "enum": [
+                  "json"
+                ]
+              },
+              "value": {
+                "type": "string",
+                "description": "A string that represents a [JSON Schema](https://json-schema.org/).\n\nJSON Schema is a declarative language that allows to annotate JSON documents\nwith types and descriptions.",
+                "example": {
+                  "properties": {
+                    "location": {
+                      "type": "string"
+                    }
+                  }
+                }
+              }
+            }
+          },
+          {
+            "type": "object",
+            "required": [
+              "type",
+              "value"
+            ],
+            "properties": {
+              "type": {
+                "type": "string",
+                "enum": [
+                  "regex"
+                ]
+              },
+              "value": {
+                "type": "string"
+              }
+            }
+          }
+        ],
+        "discriminator": {
+          "propertyName": "type"
+        }
+      },
       "Info": {
         "type": "object",
         "required": [
```
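As a usage note, here is a sketch of a request body that instantiates the new GrammarType schema. The prompt and token budget are borrowed from the grammar integration test elsewhere in this commit; the payload itself is illustrative and built with serde_json's json! macro, not taken from the diff:

```rust
// Sketch of a generate request carrying the discriminated "grammar" field.
// Assumes serde_json = "1" as the only dependency.
use serde_json::json;

fn main() {
    let request = json!({
        "inputs": "info: david holtz like trees and has two cats. ",
        "parameters": {
            "max_new_tokens": 30,
            // Discriminated on "type": either "json" (a JSON Schema) or "regex".
            "grammar": {
                "type": "json",
                "value": {
                    "properties": {
                        "location": { "type": "string" }
                    }
                }
            }
        }
    });
    println!("{}", serde_json::to_string_pretty(&request).unwrap());
}
```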
integration-tests/models/__snapshots__/test_grammar_llama/test_flash_llama_grammar_json.json

```diff
@@ -136,128 +136,128 @@
       "text": " \",\""
     },
     {
-      "id": 4230,
-      "logprob": -0.020492554,
+      "id": 29882,
+      "logprob": -0.08862305,
       "special": false,
-      "text": " last"
+      "text": "h"
     },
     {
-      "id": 1170,
-      "logprob": -0.0013818741,
+      "id": 711,
+      "logprob": -0.66259766,
       "special": false,
-      "text": "Name"
+      "text": "ob"
     },
     {
-      "id": 4710,
-      "logprob": -0.0067749023,
+      "id": 1609,
+      "logprob": -5.51939e-05,
       "special": false,
-      "text": "\":\""
+      "text": "by"
     },
     {
-      "id": 29950,
-      "logprob": -0.11578369,
+      "id": 4710,
+      "logprob": -0.23120117,
       "special": false,
-      "text": "H"
+      "text": "\":\""
     },
     {
-      "id": 14339,
-      "logprob": -0.004131317,
+      "id": 29911,
+      "logprob": -2.3730469,
       "special": false,
-      "text": "olt"
+      "text": "T"
     },
     {
-      "id": 29920,
-      "logprob": -0.0033359528,
+      "id": 11003,
+      "logprob": -0.032104492,
       "special": false,
-      "text": "z"
+      "text": "rees"
     },
     {
       "id": 3284,
-      "logprob": -0.20471191,
+      "logprob": -0.22021484,
       "special": false,
       "text": " \",\""
     },
     {
-      "id": 29882,
-      "logprob": -0.0069274902,
+      "id": 4230,
+      "logprob": -0.06726074,
       "special": false,
-      "text": "h"
+      "text": " last"
     },
     {
-      "id": 20838,
-      "logprob": -0.19580078,
+      "id": 1170,
+      "logprob": -0.003501892,
       "special": false,
-      "text": "obb"
+      "text": "Name"
     },
     {
-      "id": 29891,
-      "logprob": -2.2649765e-06,
+      "id": 4710,
+      "logprob": -0.0045661926,
       "special": false,
-      "text": "y"
+      "text": "\":\""
     },
     {
-      "id": 4710,
-      "logprob": -0.32080078,
+      "id": 29950,
+      "logprob": -0.12512207,
       "special": false,
-      "text": "\":\""
+      "text": "H"
     },
     {
-      "id": 29911,
-      "logprob": -2.1035156,
+      "id": 14339,
+      "logprob": -0.009552002,
       "special": false,
-      "text": "T"
+      "text": "olt"
     },
     {
-      "id": 11003,
-      "logprob": -0.020767212,
+      "id": 29920,
+      "logprob": -0.00042438507,
       "special": false,
-      "text": "rees"
+      "text": "z"
     },
     {
       "id": 3284,
-      "logprob": -0.6010742,
+      "logprob": -0.11651611,
       "special": false,
       "text": " \",\""
     },
     {
       "id": 29876,
-      "logprob": -0.57666016,
+      "logprob": -0.29736328,
       "special": false,
       "text": "n"
     },
     {
       "id": 398,
-      "logprob": -0.0061073303,
+      "logprob": -0.003030777,
       "special": false,
       "text": "um"
     },
     {
       "id": 29907,
-      "logprob": -0.45703125,
+      "logprob": -0.3774414,
       "special": false,
       "text": "C"
     },
     {
       "id": 1446,
-      "logprob": -0.0002872944,
+      "logprob": -0.0003130436,
       "special": false,
       "text": "ats"
     },
     {
       "id": 1115,
-      "logprob": -0.0021018982,
+      "logprob": -0.0021514893,
       "special": false,
       "text": "\":"
     },
     {
       "id": 29906,
-      "logprob": -0.08996582,
+      "logprob": -0.071899414,
       "special": false,
       "text": "2"
     },
     {
       "id": 29913,
-      "logprob": -0.021697998,
+      "logprob": -0.018997192,
       "special": false,
       "text": "}"
     },
@@ -270,5 +270,5 @@
     ],
     "top_tokens": null
   },
-  "generated_text": "{\"firstName\":\"David\",\"lastName\":\"Holtz\",\"hobby\":\"Trees\",\"numCats\":2}"
+  "generated_text": "{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}"
 }
```
integration-tests/models/test_flash_awq.py

```diff
@@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?",
@@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
     responses = await generate_load(
         flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
```
integration-tests/models/test_flash_awq_sharded.py

```diff
@@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot
 ):
```
integration-tests/models/test_flash_medusa.py

```diff
@@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_medusa_simple(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?",
@@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
     responses = await generate_load(
         flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
```
integration-tests/models/test_flash_mistral.py

```diff
@@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_mistral(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request",
@@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
     responses = await generate_load(
         flash_mistral, "Test request", max_new_tokens=10, n=4
```
integration-tests/models/test_flash_phi.py

```diff
@@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_phi(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_phi_all_params(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request",
@@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
     responses = await generate_load(
         flash_phi, "Test request", max_new_tokens=10, n=4
     )
```
integration-tests/models/test_flash_starcoder_gptq.py

```diff
@@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
@@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_gptq_default_params(
     flash_starcoder_gptq, generous_response_snapshot
 ):
@@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_gptq_load(
     flash_starcoder_gptq, generate_load, generous_response_snapshot
 ):
```
integration-tests/models/test_grammar_llama.py

```diff
@@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Whats Googles DNS",
@@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "info: david holtz like trees and has two cats. ",
@@ -92,13 +89,12 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
     assert response.details.generated_tokens == 30
     assert (
         response.generated_text
-        == '{"firstName":"David","lastName":"Holtz","hobby":"Trees","numCats":2}'
+        == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
     )
     assert response == response_snapshot


 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_load(
     flash_llama_grammar, generate_load, response_snapshot
 ):
@@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
 # this is the same as the above test, but only fires off a single request
 # this is only to ensure that the parallel and single inference produce the same result
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_grammar_single_load_instance(
     flash_llama_grammar, generate_load, response_snapshot
 ):
```
integration-tests/models/test_mamba.py

```diff
@@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "What is Deep Learning?", max_new_tokens=10
@@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "blue, red, yellow, ",
@@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_load(
     fused_kernel_mamba, generate_load, generous_response_snapshot
 ):
```
router/Cargo.toml

```diff
@@ -22,6 +22,7 @@ text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
 hf-hub = { version = "0.3.0", features = ["tokio"] }
+jsonschema = { version = "0.17.1", features = ["draft202012"] }
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
 nohash-hasher = "0.2.0"
```
router/src/lib.rs

```diff
@@ -64,39 +64,16 @@ impl HubTokenizerConfig {
     }
 }

-mod json_object_or_string_to_string {
-    use serde::{Deserialize, Deserializer};
-    use serde_json::Value;
-
-    // A custom deserializer that treats both strings and objects as strings.
-    // This provides flexibility with input formats for the 'grammar' field.
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        match value {
-            Value::String(s) => Ok(s),
-            // Safely handle serialization and return an error if it fails
-            Value::Object(o) => {
-                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
-            }
-            _ => Err(serde::de::Error::custom(
-                "expected string or object for grammar",
-            )),
-        }
-    }
-}
-
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
-    #[serde(
-        rename = "json",
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
-    Json(String),
+    /// A string that represents a [JSON Schema](https://json-schema.org/).
+    ///
+    /// JSON Schema is a declarative language that allows to annotate JSON documents
+    /// with types and descriptions.
+    #[serde(rename = "json")]
+    #[schema(example = json!({"properties": {"location": {"type": "string"}}}))]
+    Json(serde_json::Value),
    #[serde(rename = "regex")]
     Regex(String),
 }
```
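For readers unfamiliar with serde's adjacent tagging, here is a minimal, self-contained sketch (assuming `serde` with the `derive` feature and `serde_json` as dependencies) of how the enum above accepts both grammar forms; the `GrammarType` below is a local stand-in declared for illustration, not the router's type:

```rust
use serde::Deserialize;

// `tag = "type", content = "value"` makes serde expect
// {"type": ..., "value": ...} and dispatch on the "type" field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", content = "value")]
enum GrammarType {
    // Holding a raw `serde_json::Value` (rather than the old `String`)
    // lets clients send the JSON Schema as a real JSON object.
    #[serde(rename = "json")]
    Json(serde_json::Value),
    #[serde(rename = "regex")]
    Regex(String),
}

fn main() {
    let json_grammar: GrammarType = serde_json::from_str(
        r#"{"type": "json", "value": {"properties": {"location": {"type": "string"}}}}"#,
    )
    .unwrap();
    let regex_grammar: GrammarType =
        serde_json::from_str(r#"{"type": "regex", "value": "\\d{1,3}"}"#).unwrap();
    println!("{json_grammar:?}");
    println!("{regex_grammar:?}");
}
```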
router/src/server.rs

```diff
@@ -893,6 +893,7 @@ pub async fn run(
             Info,
             CompatGenerateRequest,
             GenerateRequest,
+            GrammarType,
             ChatRequest,
             Message,
             ChatCompletionChoice,
```
router/src/validation.rs

```diff
 /// Payload validation logic
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
 use crate::{GenerateParameters, GenerateRequest, GrammarType};
+use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
+use serde_json::Value;
 use text_generation_client::{
     GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
 };

@@ -313,8 +315,29 @@ impl Validation {
                 return Err(ValidationError::Grammar);
             }

             match grammar {
-                // currently both are handled the same way since compilation is done in Python
-                GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+                GrammarType::Json(json) => {
+                    let json = match json {
+                        // if value is a string, we need to parse it again to make sure its
+                        // a valid json
+                        Value::String(s) => serde_json::from_str(&s)
+                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),
+                        Value::Object(_) => Ok(json),
+                        _ => Err(ValidationError::Grammar),
+                    }?;
+
+                    // Check if the json is a valid JSONSchema
+                    JSONSchema::options()
+                        .with_draft(Draft::Draft202012)
+                        .compile(&json)
+                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+
+                    (
+                        // Serialize json to string
+                        serde_json::to_string(&json)
+                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
+                        ProtoGrammarType::Json.into(),
+                    )
+                }
                 GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
             }

@@ -486,6 +509,8 @@ pub enum ValidationError {
     Tokenizer(String),
     #[error("grammar is not supported")]
     Grammar,
+    #[error("grammar is not valid: {0}")]
+    InvalidGrammar(String),
 }

 #[cfg(test)]
```
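Standalone, the new validation path can be sketched as follows. This is a simplified rendering of the hunk above, assuming `jsonschema = "0.17"` and `serde_json` as dependencies; errors are plain `String`s instead of the router's `ValidationError`:

```rust
use jsonschema::{Draft, JSONSchema};
use serde_json::{json, Value};

// Mirrors the router logic: accept a JSON object or a string containing JSON,
// reject anything that is not a valid Draft 2020-12 schema, then forward the
// schema downstream as a string.
fn check_grammar(json: Value) -> Result<String, String> {
    let json = match json {
        // A string is parsed again to make sure it is valid JSON.
        Value::String(s) => serde_json::from_str(&s).map_err(|e| e.to_string()),
        Value::Object(_) => Ok(json),
        _ => Err("expected string or object for grammar".to_string()),
    }?;

    // Compilation fails if the value is not a usable JSON Schema.
    JSONSchema::options()
        .with_draft(Draft::Draft202012)
        .compile(&json)
        .map_err(|e| e.to_string())?;

    serde_json::to_string(&json).map_err(|e| e.to_string())
}

fn main() {
    // An object is accepted directly...
    let ok = check_grammar(json!({"properties": {"location": {"type": "string"}}}));
    // ...and a string is parsed first, preserving the old string-based input.
    let ok_str = check_grammar(json!(r#"{"type": "object"}"#));
    println!("{ok:?}");
    println!("{ok_str:?}");
}
```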
server/text_generation_server/utils/tokens.py

```diff
@@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
         scores = scores.view(B, S, -1)

         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
-        mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)

         for j in range(S):
             _scores = scores[:, j]
@@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.repetition_processor(input_ids, _scores)
             if self.frequency_processor is not None:
                 _scores = self.frequency_processor(input_ids, _scores)
-            for warper in self.warpers:
-                _scores = warper(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
```
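The tokens.py hunk moves the grammar mask ahead of the logit warpers, presumably so that warpers such as top-k prune among grammar-allowed tokens only. A toy Rust sketch of why the old order can fail (illustrative only; the server code is Python/PyTorch, and these helpers are hypothetical stand-ins):

```rust
// Set every score outside the grammar-allowed set to -inf.
fn grammar_mask(scores: &mut [f32], allowed: &[usize]) {
    for (i, s) in scores.iter_mut().enumerate() {
        if !allowed.contains(&i) {
            *s = f32::NEG_INFINITY;
        }
    }
}

// Keep the k highest scores, mask the rest to -inf (a minimal top-k warper).
fn top_k_mask(scores: &mut [f32], k: usize) {
    let mut idx: Vec<usize> = (0..scores.len()).collect();
    idx.sort_by(|&a, &b| scores[b].partial_cmp(&scores[a]).unwrap());
    for &i in &idx[k..] {
        scores[i] = f32::NEG_INFINITY;
    }
}

fn main() {
    let allowed = [2, 3]; // tokens the grammar FSM permits next

    // Old order: top-1 keeps token 0, which the grammar then bans -> no candidates.
    let mut warper_first = vec![5.0, 1.0, 0.5, 0.2];
    top_k_mask(&mut warper_first, 1);
    grammar_mask(&mut warper_first, &allowed);

    // New order: top-1 picks the best *allowed* token (token 2).
    let mut grammar_first = vec![5.0, 1.0, 0.5, 0.2];
    grammar_mask(&mut grammar_first, &allowed);
    top_k_mask(&mut grammar_first, 1);

    println!("warper first:  {warper_first:?}");
    println!("grammar first: {grammar_first:?}");
}
```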