OpenDAS / text-generation-inference · Commit 2475aede (unverified)

feat(router): add info route (#196)

close #125

Authored Apr 18, 2023 by OlivierDehaene; committed by GitHub on Apr 18, 2023.
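The commit adds a GET `/info` route that reports the served model and build metadata. As a minimal client sketch, the new route could be exercised like this once a router is running; the address and the tokio/reqwest/serde dependencies are assumptions, not part of this commit, and the field list mirrors the `Info` schema added below.

// Hypothetical client sketch (not part of this commit): query the new
// `/info` route and deserialize the payload into the fields documented
// by the `Info` schema. The address is an assumption; substitute the
// one your router actually listens on.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Info {
    model_id: String,
    model_sha: Option<String>,
    model_pipeline_tag: Option<String>,
    version: String,
    sha: Option<String>,
}

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let info: Info = reqwest::get("http://127.0.0.1:3000/info") // assumed address
        .await?
        .json()
        .await?;
    println!("{info:?}");
    Ok(())
}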
Parent: b927244e

Showing 8 changed files with 409 additions and 117 deletions (+409 −117).
Cargo.lock            +23   −0
docs/openapi.json     +188  −20
launcher/src/main.rs  +6    −0
router/Cargo.toml     +4    −1
router/build.rs       +7    −0
router/src/lib.rs     +23   −0
router/src/main.rs    +42   −24
router/src/server.rs  +116  −72
Cargo.lock

@@ -2430,6 +2430,7 @@ dependencies = [
  "tracing-subscriber",
  "utoipa",
  "utoipa-swagger-ui",
+ "vergen",
 ]

 [[package]]

@@ -2468,8 +2469,10 @@ version = "0.3.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
 dependencies = [
+ "itoa",
  "serde",
  "time-core",
+ "time-macros",
 ]

 [[package]]

@@ -2478,6 +2481,15 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"

+[[package]]
+name = "time-macros"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36"
+dependencies = [
+ "time-core",
+]
+
 [[package]]
 name = "tinyvec"
 version = "1.6.0"

@@ -2966,6 +2978,17 @@ version = "0.2.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"

+[[package]]
+name = "vergen"
+version = "8.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c1b86a8af1dedf089b1c78338678e4c7492b6045649042d94faf19690499d236"
+dependencies = [
+ "anyhow",
+ "rustversion",
+ "time",
+]
+
 [[package]]
 name = "version_check"
 version = "0.9.4"
docs/openapi.json

@@ -4,8 +4,7 @@
     "title": "Text Generation Inference",
     "description": "Text Generation Webserver",
     "contact": {
-      "name": "Olivier Dehaene",
-      "email": "olivier@huggingface.co"
+      "name": "Olivier Dehaene"
     },
     "license": {
       "name": "Apache 2.0",

@@ -14,6 +13,83 @@
     "version": "0.5.0"
   },
   "paths": {
+    "/": {
+      "post": {
+        "tags": ["Text Generation Inference"],
+        "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
+        "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
+        "operationId": "compat_generate",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": { "$ref": "#/components/schemas/CompatGenerateRequest" }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "See /generate or /generate_stream"
+          },
+          "422": {
+            "description": "Input validation error",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/ErrorResponse" },
+                "example": { "error": "Input validation error" }
+              }
+            }
+          },
+          "424": {
+            "description": "Generation Error",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/ErrorResponse" },
+                "example": { "error": "Request failed during generation" }
+              }
+            }
+          },
+          "429": {
+            "description": "Model is overloaded",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/ErrorResponse" },
+                "example": { "error": "Model is overloaded" }
+              }
+            }
+          },
+          "500": {
+            "description": "Incomplete generation",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/ErrorResponse" },
+                "example": { "error": "Incomplete generation" }
+              }
+            }
+          }
+        }
+      }
+    },
     "/generate": {
       "post": {
         "tags": [

@@ -95,8 +171,7 @@
             }
           }
         }
-      },
-      "deprecated": false
+      }
     },
     "/generate_stream": {

@@ -180,8 +255,29 @@
             }
           }
         }
-      },
-      "deprecated": false
+      }
     },
+    "/info": {
+      "get": {
+        "tags": ["Text Generation Inference"],
+        "summary": "Text Generation Inference endpoint info",
+        "description": "Text Generation Inference endpoint info",
+        "operationId": "get_model_info",
+        "responses": {
+          "200": {
+            "description": "Served model info",
+            "content": {
+              "application/json": {
+                "schema": { "$ref": "#/components/schemas/Info" }
+              }
+            }
+          }
+        }
+      }
+    },
     "/metrics": {

@@ -203,8 +299,7 @@
             }
          }
        }
-      },
-      "deprecated": false
+      }
     }
   },

@@ -230,7 +325,8 @@
         "generated_tokens": {
           "type": "integer",
           "format": "int32",
-          "example": 1
+          "example": 1,
+          "minimum": 0.0
         },
         "prefill": {
           "type": "array",

@@ -242,7 +338,8 @@
           "type": "integer",
           "format": "int64",
           "example": 42,
-          "nullable": true
+          "nullable": true,
+          "minimum": 0.0
         },
         "tokens": {
           "type": "array",

@@ -252,6 +349,24 @@
           }
         }
       },
+      "CompatGenerateRequest": {
+        "type": "object",
+        "required": ["inputs"],
+        "properties": {
+          "inputs": {
+            "type": "string",
+            "example": "My name is Olivier and I"
+          },
+          "parameters": {
+            "$ref": "#/components/schemas/GenerateParameters"
+          },
+          "stream": {
+            "type": "boolean"
+          }
+        }
+      },
       "Details": {
         "type": "object",
         "required": [

@@ -265,7 +380,8 @@
           "type": "array",
           "items": {
             "$ref": "#/components/schemas/BestOfSequence"
-          }
+          },
+          "nullable": true
         },
         "finish_reason": {
           "$ref": "#/components/schemas/FinishReason"

@@ -273,7 +389,8 @@
         "generated_tokens": {
           "type": "integer",
           "format": "int32",
-          "example": 1
+          "example": 1,
+          "minimum": 0.0
         },
         "prefill": {
           "type": "array",

@@ -285,7 +402,8 @@
           "type": "integer",
           "format": "int64",
           "example": 42,
-          "nullable": true
+          "nullable": true,
+          "minimum": 0.0
         },
         "tokens": {
           "type": "array",

@@ -326,6 +444,7 @@
           "default": "null",
           "example": 1,
           "nullable": true,
+          "minimum": 0.0,
           "exclusiveMinimum": 0.0
         },
         "details": {

@@ -341,6 +460,7 @@
           "type": "integer",
           "format": "int32",
           "default": "20",
+          "minimum": 0.0,
           "exclusiveMaximum": 512.0,
           "exclusiveMinimum": 0.0
         },

@@ -364,6 +484,7 @@
           "default": "null",
           "example": "null",
           "nullable": true,
+          "minimum": 0.0,
           "exclusiveMinimum": 0.0
         },
         "stop": {

@@ -405,7 +526,8 @@
           "type": "integer",
           "default": "null",
           "example": "null",
-          "nullable": true
+          "nullable": true,
+          "minimum": 0.0
         },
         "typical_p": {
           "type": "number",

@@ -445,7 +567,12 @@
         ],
         "properties": {
           "details": {
-            "$ref": "#/components/schemas/Details"
+            "allOf": [
+              { "$ref": "#/components/schemas/Details" }
+            ],
+            "nullable": true
           },
           "generated_text": {
             "type": "string",

@@ -453,6 +580,38 @@
           }
         }
       },
+      "Info": {
+        "type": "object",
+        "required": [
+          "model_id",
+          "version"
+        ],
+        "properties": {
+          "model_id": {
+            "type": "string",
+            "example": "bigscience/blomm-560m"
+          },
+          "model_pipeline_tag": {
+            "type": "string",
+            "example": "text-generation",
+            "nullable": true
+          },
+          "model_sha": {
+            "type": "string",
+            "example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
+            "nullable": true
+          },
+          "sha": {
+            "type": "string",
+            "example": "null",
+            "nullable": true
+          },
+          "version": {
+            "type": "string",
+            "example": "0.5.0"
+          }
+        }
+      },
       "PrefillToken": {
         "type": "object",
         "required": [

@@ -464,7 +623,8 @@
         "id": {
           "type": "integer",
           "format": "int32",
-          "example": 0
+          "example": 0,
+          "minimum": 0.0
         },
         "logprob": {
           "type": "number",

@@ -491,13 +651,15 @@
         "generated_tokens": {
           "type": "integer",
           "format": "int32",
-          "example": 1
+          "example": 1,
+          "minimum": 0.0
         },
         "seed": {
           "type": "integer",
           "format": "int64",
           "example": 42,
-          "nullable": true
+          "nullable": true,
+          "minimum": 0.0
         }
       }
     },

@@ -508,7 +670,12 @@
         ],
         "properties": {
           "details": {
-            "$ref": "#/components/schemas/StreamDetails"
+            "allOf": [
+              { "$ref": "#/components/schemas/StreamDetails" }
+            ],
+            "nullable": true
           },
           "generated_text": {
             "type": "string",

@@ -533,7 +700,8 @@
         "id": {
           "type": "integer",
           "format": "int32",
-          "example": 0
+          "example": 0,
+          "minimum": 0.0
         },
         "logprob": {
           "type": "number",
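The new `Info` schema fixes the shape of the `/info` payload. Purely as an illustration, its documented example values assemble into the response below, built here with serde_json's `json!` macro; note that "bigscience/blomm-560m" reproduces the example string exactly as it appears in the schema.

// Illustrative only: the example `/info` response implied by the `Info`
// schema's documented example values (not emitted anywhere by this commit).
use serde_json::json;

fn main() {
    let example = json!({
        "model_id": "bigscience/blomm-560m",
        "model_sha": "e985a63cdc139290c5f700ff1929f0b5942cced2",
        "model_pipeline_tag": "text-generation",
        "version": "0.5.0",
        "sha": null
    });
    println!("{}", serde_json::to_string_pretty(&example).unwrap());
}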
launcher/src/main.rs

@@ -392,6 +392,12 @@ fn main() -> ExitCode {
         model_id,
     ];

+    // Model optional revision
+    if let Some(ref revision) = revision {
+        argv.push("--revision".to_string());
+        argv.push(revision.to_string())
+    }
+
     if json_output {
         argv.push("--json-output".to_string());
    }
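This is the usual pattern for forwarding an optional CLI flag to a spawned process. A self-contained sketch, where the helper name is illustrative and `revision` is an `Option<String>` as in the launcher:

// Sketch of the launcher change in isolation: append `--revision <rev>`
// to a child process's argv only when a revision was supplied.
fn push_revision(argv: &mut Vec<String>, revision: &Option<String>) {
    if let Some(revision) = revision {
        argv.push("--revision".to_string());
        argv.push(revision.to_string());
    }
}

fn main() {
    let mut argv = vec!["text-generation-router".to_string()];
    push_revision(&mut argv, &Some("main".to_string()));
    assert_eq!(argv, ["text-generation-router", "--revision", "main"]);
}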
router/Cargo.toml

@@ -4,6 +4,7 @@ version = "0.5.0"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 description = "Text Generation Webserver"
+build = "build.rs"

 [lib]
 path = "src/lib.rs"

@@ -26,7 +27,7 @@ nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
 rand = "0.8.5"
 reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"

@@ -39,3 +40,5 @@ tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
 utoipa = { version = "3.0.1", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
+
+[build-dependencies]
+vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
router/build.rs (new file, 0 → 100644)

use std::error::Error;
use vergen::EmitBuilder;

fn main() -> Result<(), Box<dyn Error>> {
    EmitBuilder::builder().git_sha(false).emit()?;
    Ok(())
}
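With vergen wired in as a build dependency, this build script exposes the current git sha to the compiler as the VERGEN_GIT_SHA environment variable, which server.rs reads below via `option_env!`. A minimal sketch of that consuming side:

// Sketch: reading the sha vergen emits at build time. `option_env!` is
// resolved at compile time and yields None when the variable was not set
// (for example, when building outside a git checkout).
const SHA: Option<&'static str> = option_env!("VERGEN_GIT_SHA");

fn main() {
    println!("router sha: {}", SHA.unwrap_or("unknown"));
}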
router/src/lib.rs

@@ -10,6 +10,29 @@ use serde::{Deserialize, Serialize};
 use utoipa::ToSchema;
 use validation::Validation;

+/// Hub type
+#[derive(Clone, Debug, Deserialize)]
+pub struct ModelInfo {
+    #[serde(rename(deserialize = "id"))]
+    pub model_id: String,
+    pub sha: Option<String>,
+    pub pipeline_tag: Option<String>,
+}
+
+#[derive(Clone, Debug, Serialize, ToSchema)]
+pub struct Info {
+    #[schema(example = "bigscience/blomm-560m")]
+    pub model_id: String,
+    #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
+    pub model_sha: Option<String>,
+    #[schema(nullable = true, example = "text-generation")]
+    pub model_pipeline_tag: Option<String>,
+    #[schema(example = "0.5.0")]
+    pub version: &'static str,
+    #[schema(nullable = true, example = "null")]
+    pub sha: Option<&'static str>,
+}
+
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
     #[serde(default)]
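`ModelInfo` mirrors the Hub API response, with the Hub's `id` field renamed to `model_id` on deserialization, while `Info` is the serialized `/info` payload. A quick sketch of the rename in action, against a hypothetical, truncated Hub payload:

// Hypothetical, truncated Hub payload containing only the three fields
// ModelInfo keeps. The serde rename maps the Hub's "id" onto `model_id`.
use serde::Deserialize;

#[derive(Clone, Debug, Deserialize)]
pub struct ModelInfo {
    #[serde(rename(deserialize = "id"))]
    pub model_id: String,
    pub sha: Option<String>,
    pub pipeline_tag: Option<String>,
}

fn main() {
    let payload = r#"{"id": "bigscience/bloom-560m", "sha": null, "pipeline_tag": "text-generation"}"#;
    let info: ModelInfo = serde_json::from_str(payload).unwrap();
    assert_eq!(info.model_id, "bigscience/bloom-560m");
    assert_eq!(info.pipeline_tag.as_deref(), Some("text-generation"));
}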
router/src/main.rs

@@ -10,8 +10,8 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use text_generation_client::ShardedClient;
-use text_generation_router::server;
-use tokenizers::Tokenizer;
+use text_generation_router::{server, ModelInfo};
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;

@@ -41,6 +41,8 @@ struct Args {
     master_shard_uds_path: String,
     #[clap(default_value = "bigscience/bloom", long, env)]
     tokenizer_name: String,
+    #[clap(default_value = "main", long, env)]
+    revision: String,
     #[clap(default_value = "2", long, env)]
     validation_workers: usize,
     #[clap(long, env)]

@@ -66,6 +68,7 @@ fn main() -> Result<(), std::io::Error> {
         port,
         master_shard_uds_path,
         tokenizer_name,
+        revision,
         validation_workers,
         json_output,
         otlp_endpoint,

@@ -90,16 +93,19 @@ fn main() -> Result<(), std::io::Error> {
     // Tokenizer instance
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
-    let tokenizer =
-        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
-        {
+    let local_model = local_path.exists() && local_path.is_dir();
+    let tokenizer = if local_model {
         // Load local tokenizer
         Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
     } else {
         // Download and instantiate tokenizer
         // We need to download it outside of the Tokio runtime
-        Tokenizer::from_pretrained(tokenizer_name.clone(), None).ok()
+        let params = FromPretrainedParameters {
+            revision: revision.clone(),
+            ..Default::default()
+        };
+        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
     };

     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()

@@ -116,25 +122,23 @@ fn main() -> Result<(), std::io::Error> {
         tracing::warn!("Rust input length validation and truncation is disabled");
     }

-    // Get pipeline tag
-    let model_info = reqwest::get(format!(
-        "https://huggingface.co/api/models/{tokenizer_name}"
-    ))
-    .await
-    .expect("Could not connect to hf.co")
-    .text()
-    .await
-    .expect("error when retrieving model info from hf.co");
-    let model_info: serde_json::Value =
-        serde_json::from_str(&model_info).expect("unable to parse model info");
+    // Get Model info
+    let model_info = match local_model {
+        true => ModelInfo {
+            model_id: tokenizer_name.clone(),
+            sha: None,
+            pipeline_tag: None,
+        },
+        false => get_model_info(&tokenizer_name, &revision).await,
+    };

     // if pipeline-tag == text-generation we default to return_full_text = true
-    let compat_return_full_text = match model_info.get("pipeline_tag") {
+    let compat_return_full_text = match &model_info.pipeline_tag {
         None => {
             tracing::warn!("no pipeline tag found for model {tokenizer_name}");
             false
         }
-        Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
+        Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
     };

     // Instantiate sharded client from the master unix socket

@@ -153,6 +157,7 @@ fn main() -> Result<(), std::io::Error> {
             // Run server
             server::run(
+                model_info,
                 compat_return_full_text,
                 max_concurrent_requests,
                 max_best_of,

@@ -226,3 +231,16 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
         .with(layers)
         .init();
 }
+
+/// get model info from the Huggingface Hub
+pub async fn get_model_info(model_id: &str, revision: &str) -> ModelInfo {
+    let model_info = reqwest::get(format!(
+        "https://huggingface.co/api/models/{model_id}/revision/{revision}"
+    ))
+    .await
+    .expect("Could not connect to hf.co")
+    .text()
+    .await
+    .expect("error when retrieving model info from hf.co");
+    serde_json::from_str(&model_info).expect("unable to parse model info")
+}
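Two behavioral changes ride along here: tokenizer downloads are now pinned to the requested revision via `FromPretrainedParameters`, and local model directories skip the Hub round-trip entirely, so fully local deployments no longer hit the `expect("Could not connect to hf.co")` panic. A standalone sketch of the revision pinning, assuming the `tokenizers` crate is built with its HTTP download feature; the model name and revision stand in for the CLI arguments:

// Standalone sketch of the revision-pinning pattern used above.
use tokenizers::{FromPretrainedParameters, Tokenizer};

fn main() {
    let params = FromPretrainedParameters {
        revision: "main".to_string(),
        ..Default::default()
    };
    let tokenizer = Tokenizer::from_pretrained("bigscience/bloom-560m", Some(params));
    println!("tokenizer loaded: {}", tokenizer.is_ok());
}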
router/src/server.rs

@@ -3,8 +3,8 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails,
-    StreamResponse, Token, Validation,
+    GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
+    StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};

@@ -27,7 +27,24 @@ use tracing::{info_span, instrument, Instrument};
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;

+/// Compatibility route with api-inference and AzureML
+/// Generate tokens if `stream == false` or a stream of token if `stream == true`
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/",
+    request_body = CompatGenerateRequest,
+    responses(
+        (status = 200, description = "See /generate or /generate_stream"),
+        (status = 424, description = "Generation Error", body = ErrorResponse,
+            example = json!({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = ErrorResponse,
+            example = json!({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = ErrorResponse,
+            example = json!({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = ErrorResponse,
+            example = json!({"error": "Incomplete generation"})),
+    )
+)]
 #[instrument(skip(infer))]
 async fn compat_generate(
     default_return_full_text: Extension<bool>,

@@ -53,6 +70,26 @@ async fn compat_generate(
     }
 }

+/// Text Generation Inference endpoint info
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/info",
+    responses((status = 200, description = "Served model info", body = Info))
+)]
+#[instrument]
+async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
+    let model_info = model_info.0;
+    let info = Info {
+        version: env!("CARGO_PKG_VERSION"),
+        sha: option_env!("VERGEN_GIT_SHA"),
+        model_id: model_info.model_id,
+        model_sha: model_info.sha,
+        model_pipeline_tag: model_info.pipeline_tag,
+    };
+    Json(info)
+}
+
 /// Health check method
 #[instrument(skip(infer))]
 async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {

@@ -87,21 +124,21 @@
 /// Generate tokens
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/generate",
     request_body = GenerateRequest,
     responses(
         (status = 200, description = "Generated Text", body = GenerateResponse),
         (status = 424, description = "Generation Error", body = ErrorResponse,
             example = json!({"error": "Request failed during generation"})),
         (status = 429, description = "Model is overloaded", body = ErrorResponse,
             example = json!({"error": "Model is overloaded"})),
         (status = 422, description = "Input validation error", body = ErrorResponse,
             example = json!({"error": "Input validation error"})),
         (status = 500, description = "Incomplete generation", body = ErrorResponse,
             example = json!({"error": "Incomplete generation"})),
     )
 )]
 #[instrument(
     skip(infer),

@@ -264,26 +301,26 @@ async fn generate(
 /// Generate a stream of token using Server-Sent Events
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/generate_stream",
     request_body = GenerateRequest,
     responses(
         (status = 200, description = "Generated Text", body = StreamResponse,
             content_type = "text/event-stream"),
         (status = 424, description = "Generation Error", body = ErrorResponse,
             example = json!({"error": "Request failed during generation"}),
             content_type = "text/event-stream"),
         (status = 429, description = "Model is overloaded", body = ErrorResponse,
             example = json!({"error": "Model is overloaded"}),
             content_type = "text/event-stream"),
         (status = 422, description = "Input validation error", body = ErrorResponse,
             example = json!({"error": "Input validation error"}),
             content_type = "text/event-stream"),
         (status = 500, description = "Incomplete generation", body = ErrorResponse,
             example = json!({"error": "Incomplete generation"}),
             content_type = "text/event-stream"),
     )
 )]
 #[instrument(
     skip(infer),

@@ -447,10 +484,10 @@ async fn generate_stream(
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
     get,
     tag = "Text Generation Inference",
     path = "/metrics",
     responses((status = 200, description = "Prometheus Metrics", body = String))
 )]
 async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()

@@ -459,6 +496,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
+    model_info: ModelInfo,
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_best_of: usize,

@@ -476,36 +514,40 @@ pub async fn run(
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
         paths(
+            get_model_info,
+            compat_generate,
             generate,
             generate_stream,
             metrics,
         ),
         components(
             schemas(
+                Info,
+                CompatGenerateRequest,
                 GenerateRequest,
                 GenerateParameters,
                 PrefillToken,
                 Token,
                 GenerateResponse,
                 BestOfSequence,
                 Details,
                 FinishReason,
                 StreamResponse,
                 StreamDetails,
                 ErrorResponse,
             )
         ),
         tags(
             (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
         ),
         info(
             title = "Text Generation Inference",
             license(
                 name = "Apache 2.0",
                 url = "https://www.apache.org/licenses/LICENSE-2.0"
             )
         )
     )]
     struct ApiDoc;

@@ -584,6 +626,7 @@ pub async fn run(
         .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
         // Base routes
         .route("/", post(compat_generate))
+        .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
         // AWS Sagemaker route

@@ -596,6 +639,7 @@ pub async fn run(
         .route("/ping", get(health))
         // Prometheus metrics route
         .route("/metrics", get(metrics))
+        .layer(Extension(model_info))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
         .layer(Extension(prom_handle))
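The handler receives `ModelInfo` through axum's `Extension` layer, the same mechanism the router already uses for `Infer` and the Prometheus handle, while `version` and `sha` are baked in at compile time. A reduced sketch of that pattern, assuming an axum 0.6-style API; the struct and route names mirror the diff, everything else is illustrative:

// Reduced sketch of the Extension pattern: state is attached with
// `.layer(Extension(...))` and extracted by value in the handler.
use axum::{extract::Extension, routing::get, Json, Router};
use serde::Serialize;

#[derive(Clone, Serialize)]
struct ModelInfo {
    model_id: String,
}

async fn get_model_info(Extension(model_info): Extension<ModelInfo>) -> Json<ModelInfo> {
    Json(model_info)
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/info", get(get_model_info))
        .layer(Extension(ModelInfo {
            model_id: "bigscience/bloom-560m".to_string(),
        }));

    // Assumed local address for the sketch.
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}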