OpenDAS / text-generation-inference
Unverified commit fbeb1c44, authored Jan 10, 2024 by OlivierDehaene; committed by GitHub on Jan 10, 2024.
fix: follow base model for tokenizer in router (#1424)
Close #1422
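
In short: instead of spawning a manual Tokio runtime and downloading the tokenizer with the blocking `Tokenizer::from_pretrained` before entering it, `main` becomes `#[tokio::main] async` and resolves everything through hf-hub's async API. When the model repo itself ships no `tokenizer.json`, the router now falls back to the tokenizer of the base model named in the repo's `config.json`. Condensed from the router/src/main.rs diff below (not the verbatim code):

    // New lookup order for the validation tokenizer:
    let tokenizer = match api_repo.get("tokenizer.json").await {
        // 1. the model repo ships its own fast tokenizer
        Ok(filename) => Tokenizer::from_file(filename).ok(),
        // 2. otherwise follow `base_model_name_or_path` from config.json
        //    and load the base model's tokenizer
        Err(_) => get_base_tokenizer(&api, &api_repo).await,
    };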
Parent: 91d72675

Showing 3 changed files with 185 additions and 159 deletions (+185 -159):
Cargo.lock          +4    -0
router/Cargo.toml   +1    -1
router/src/main.rs  +180  -158

Cargo.lock:

@@ -950,13 +950,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
 dependencies = [
  "dirs 5.0.1",
+ "futures",
  "indicatif",
  "log",
  "native-tls",
+ "num_cpus",
  "rand",
+ "reqwest",
  "serde",
  "serde_json",
  "thiserror",
+ "tokio",
  "ureq",
 ]

router/Cargo.toml:

@@ -21,6 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
 text-generation-client = { path = "client" }
 clap = { version = "4.4.5", features = ["derive", "env"] }
 futures = "0.3.28"
+hf-hub = { version = "0.3.0", features = ["tokio"] }
 metrics = "0.21.1"
 metrics-exporter-prometheus = { version = "0.12.1", features = [] }
 nohash-hasher = "0.2.0"
@@ -41,7 +42,6 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
 utoipa = { version = "3.5.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
-hf-hub = "0.3.1"
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }

 [build-dependencies]
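
Pinning hf-hub to 0.3.0 with the `tokio` feature is what pulls `futures`, `num_cpus`, `reqwest`, and `tokio` into the hf-hub entry in Cargo.lock above, and it is what lets the router await Hub downloads on its own runtime instead of going through the blocking tokenizers download path. A minimal sketch of that async API, mirroring the calls used in router/src/main.rs below (the helper name is ours, not part of this commit):

    use hf_hub::api::tokio::ApiBuilder;
    use hf_hub::{Repo, RepoType};
    use std::path::PathBuf;

    // Download (or reuse the cached copy of) one repo file on the Tokio runtime.
    async fn fetch_file(repo_id: String, revision: String, file: &str) -> Option<PathBuf> {
        let api = ApiBuilder::new().with_progress(false).build().ok()?;
        let api_repo = api.repo(Repo::with_revision(repo_id, RepoType::Model, revision));
        api_repo.get(file).await.ok()
    }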

router/src/main.rs:

@@ -1,19 +1,22 @@
-/// Text Generation Inference webserver entrypoint
 use axum::http::HeaderValue;
 use clap::Parser;
+use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
+use hf_hub::{Repo, RepoType};
 use opentelemetry::sdk::propagation::TraceContextPropagator;
 use opentelemetry::sdk::trace;
 use opentelemetry::sdk::trace::Sampler;
 use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
+/// Text Generation Inference webserver entrypoint
+use std::fs::File;
+use std::io::BufReader;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
-use std::time::Duration;
 use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
 use thiserror::Error;
-use tokenizers::{FromPretrainedParameters, Tokenizer};
+use tokenizers::Tokenizer;
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -69,7 +72,8 @@ struct Args {
     ngrok_edge: Option<String>,
 }

-fn main() -> Result<(), RouterError> {
+#[tokio::main]
+async fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -98,6 +102,9 @@ fn main() -> Result<(), RouterError> {
         ngrok_edge,
     } = args;

+    // Launch Tokio runtime
+    init_logging(otlp_endpoint, json_output);
+
     // Validate args
     if max_input_length >= max_total_tokens {
         return Err(RouterError::ArgumentValidation(
@@ -141,53 +148,63 @@ fn main() -> Result<(), RouterError> {
     // This will only be used to validate payloads
     let local_path = Path::new(&tokenizer_name);
     let local_model = local_path.exists() && local_path.is_dir();
-    let tokenizer = if local_model {
+
+    let (tokenizer, model_info) = if local_model {
+        // Get Model info
+        let model_info = HubModelInfo {
+            model_id: tokenizer_name.clone(),
+            sha: None,
+            pipeline_tag: None,
+        };
+
         // Load local tokenizer
-        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
+        let tokenizer = Tokenizer::from_file(local_path.join("tokenizer.json")).ok();
+
+        (tokenizer, model_info)
     } else {
-        // Download and instantiate tokenizer
-        // We need to download it outside of the Tokio runtime
-        let params = FromPretrainedParameters {
-            revision: revision.clone().unwrap_or("main".to_string()),
-            auth_token: authorization_token.clone(),
-            ..Default::default()
-        };
-        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
-    };
-
-    // Launch Tokio runtime
-    tokio::runtime::Builder::new_multi_thread()
-        .enable_all()
-        .build()?
-        .block_on(async {
-            init_logging(otlp_endpoint, json_output);
-
-            if tokenizer.is_none() {
-                tracing::warn!(
-                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
-                );
-                tracing::warn!("Rust input length validation and truncation is disabled");
-            }
-
-            // Get Model info
-            let model_info = match local_model {
-                true => HubModelInfo {
-                    model_id: tokenizer_name.clone(),
-                    sha: None,
-                    pipeline_tag: None,
-                },
-                false => get_model_info(&tokenizer_name, revision, authorization_token)
-                    .await
-                    .unwrap_or_else(|| {
-                        tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
-                        HubModelInfo {
-                            model_id: tokenizer_name.to_string(),
-                            sha: None,
-                            pipeline_tag: None,
-                        }
-                    }),
-            };
+        let mut builder = ApiBuilder::new()
+            .with_progress(false)
+            .with_token(authorization_token);
+
+        if let Some(cache_dir) = std::env::var("HUGGINGFACE_HUB_CACHE").ok() {
+            builder = builder.with_cache_dir(cache_dir.into());
+        }
+
+        if revision.is_none() {
+            tracing::warn!("`--revision` is not set");
+            tracing::warn!("We strongly advise to set it to a known supported commit.");
+        }
+
+        let api = builder.build().unwrap();
+        let api_repo = api.repo(Repo::with_revision(
+            tokenizer_name.clone(),
+            RepoType::Model,
+            revision.clone().unwrap_or("main".to_string()),
+        ));
+
+        // Get Model info
+        let model_info = get_model_info(&api_repo).await.unwrap_or_else(|| {
+            tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
+            HubModelInfo {
+                model_id: tokenizer_name.to_string(),
+                sha: None,
+                pipeline_tag: None,
+            }
+        });
+
+        let tokenizer = match api_repo.get("tokenizer.json").await {
+            Ok(tokenizer_filename) => Tokenizer::from_file(tokenizer_filename).ok(),
+            Err(_) => get_base_tokenizer(&api, &api_repo).await,
+        };
+
+        (tokenizer, model_info)
+    };
+
+    if tokenizer.is_none() {
+        tracing::warn!("Could not find a fast tokenizer implementation for {tokenizer_name}");
+        tracing::warn!("Rust input length validation and truncation is disabled");
+    }

     // if pipeline-tag == text-generation we default to return_full_text = true
     let compat_return_full_text = match &model_info.pipeline_tag {
         None => {
@@ -212,15 +229,18 @@ fn main() -> Result<(), RouterError> {
     // Warmup model
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
-        .warmup(max_input_length as u32, max_batch_prefill_tokens, max_total_tokens as u32)
+        .warmup(
+            max_input_length as u32,
+            max_batch_prefill_tokens,
+            max_total_tokens as u32,
+        )
         .await
         .map_err(RouterError::Warmup)?
     {
         // Older models do not support automatic max-batch-total-tokens
         None => {
-            let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
-                16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
-            );
+            let max_batch_total_tokens = max_batch_total_tokens
+                .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
             tracing::warn!("Model does not support automatic max batch total tokens");
             max_batch_total_tokens
         }
@@ -280,7 +300,6 @@ fn main() -> Result<(), RouterError> {
     )
     .await?;
     Ok(())
-    })
 }

 /// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
@@ -339,30 +358,8 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }

 /// get model info from the Huggingface Hub
-pub async fn get_model_info(
-    model_id: &str,
-    revision: Option<String>,
-    token: Option<String>,
-) -> Option<HubModelInfo> {
-    let revision = match revision {
-        None => {
-            tracing::warn!("`--revision` is not set");
-            tracing::warn!("We strongly advise to set it to a known supported commit.");
-            "main".to_string()
-        }
-        Some(revision) => revision,
-    };
-
-    let client = reqwest::Client::new();
-    // Poor man's urlencode
-    let revision = revision.replace('/', "%2F");
-    let url = format!("https://huggingface.co/api/models/{model_id}/revision/{revision}");
-
-    let mut builder = client.get(url).timeout(Duration::from_secs(5));
-    if let Some(token) = token {
-        builder = builder.bearer_auth(token);
-    }
-
-    let response = builder.send().await.ok()?;
+pub async fn get_model_info(api: &ApiRepo) -> Option<HubModelInfo> {
+    let response = api.info_request().send().await.ok()?;

     if response.status().is_success() {
         let hub_model_info: HubModelInfo =
@@ -379,6 +376,31 @@ pub async fn get_model_info(
     }
 }

+/// get base tokenizer
+pub async fn get_base_tokenizer(api: &Api, api_repo: &ApiRepo) -> Option<Tokenizer> {
+    let config_filename = api_repo.get("config.json").await.ok()?;
+
+    // Open the file in read-only mode with buffer.
+    let file = File::open(config_filename).ok()?;
+    let reader = BufReader::new(file);
+
+    // Read the JSON contents of the file as an instance of `User`.
+    let config: serde_json::Value = serde_json::from_reader(reader).ok()?;
+
+    if let Some(serde_json::Value::String(base_model_id)) = config.get("base_model_name_or_path") {
+        let api_base_repo = api.repo(Repo::with_revision(
+            base_model_id.to_string(),
+            RepoType::Model,
+            "main".to_string(),
+        ));
+
+        let tokenizer_filename = api_base_repo.get("tokenizer.json").await.ok()?;
+        Tokenizer::from_file(tokenizer_filename).ok()
+    } else {
+        None
+    }
+}
+
 #[derive(Debug, Error)]
 enum RouterError {
     #[error("Argument validation error: {0}")]
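
For context on the new `get_base_tokenizer` fallback: fine-tuned repos (PEFT-style exports in particular) often ship a `config.json` naming their base model but no `tokenizer.json`. A hypothetical sketch of the field it keys on; the repo name is made up and only the `serde_json` crate is required:

    fn main() {
        // config.json of a hypothetical fine-tune that ships no tokenizer.json
        let config: serde_json::Value = serde_json::json!({
            "architectures": ["LlamaForCausalLM"],
            "base_model_name_or_path": "meta-llama/Llama-2-7b-hf"
        });
        // Same extraction as get_base_tokenizer above; the router then fetches
        // tokenizer.json from this base repo at revision "main".
        if let Some(serde_json::Value::String(base_model_id)) =
            config.get("base_model_name_or_path")
        {
            println!("falling back to tokenizer from {base_model_id}");
        }
    }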