Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
fbeb1c44
Unverified
Commit
fbeb1c44
authored
Jan 10, 2024
by
OlivierDehaene
Committed by
GitHub
Jan 10, 2024
Browse files
fix: follow base model for tokenizer in router (#1424)
Close #1422
parent
91d72675
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
185 additions
and
159 deletions
+185
-159
Cargo.lock
Cargo.lock
+4
-0
router/Cargo.toml
router/Cargo.toml
+1
-1
router/src/main.rs
router/src/main.rs
+180
-158
No files found.
Cargo.lock
View file @
fbeb1c44
...
@@ -950,13 +950,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
...
@@ -950,13 +950,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732"
dependencies = [
dependencies = [
"dirs 5.0.1",
"dirs 5.0.1",
"futures",
"indicatif",
"indicatif",
"log",
"log",
"native-tls",
"native-tls",
"num_cpus",
"rand",
"rand",
"reqwest",
"serde",
"serde",
"serde_json",
"serde_json",
"thiserror",
"thiserror",
"tokio",
"ureq",
"ureq",
]
]
...
...
router/Cargo.toml
View file @
fbeb1c44
...
@@ -21,6 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
...
@@ -21,6 +21,7 @@ axum-tracing-opentelemetry = "0.14.1"
text-generation-client
=
{
path
=
"client"
}
text-generation-client
=
{
path
=
"client"
}
clap
=
{
version
=
"4.4.5"
,
features
=
[
"derive"
,
"env"
]
}
clap
=
{
version
=
"4.4.5"
,
features
=
[
"derive"
,
"env"
]
}
futures
=
"0.3.28"
futures
=
"0.3.28"
hf-hub
=
{
version
=
"0.3.0"
,
features
=
["tokio"]
}
metrics
=
"0.21.1"
metrics
=
"0.21.1"
metrics-exporter-prometheus
=
{
version
=
"0.12.1"
,
features
=
[]
}
metrics-exporter-prometheus
=
{
version
=
"0.12.1"
,
features
=
[]
}
nohash-hasher
=
"0.2.0"
nohash-hasher
=
"0.2.0"
...
@@ -41,7 +42,6 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
...
@@ -41,7 +42,6 @@ tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
utoipa
=
{
version
=
"3.5.0"
,
features
=
["axum_extras"]
}
utoipa
=
{
version
=
"3.5.0"
,
features
=
["axum_extras"]
}
utoipa-swagger-ui
=
{
version
=
"3.1.5"
,
features
=
["axum"]
}
utoipa-swagger-ui
=
{
version
=
"3.1.5"
,
features
=
["axum"]
}
ngrok
=
{
version
=
"0.13.1"
,
features
=
["axum"]
,
optional
=
true
}
ngrok
=
{
version
=
"0.13.1"
,
features
=
["axum"]
,
optional
=
true
}
hf-hub
=
"0.3.1"
init-tracing-opentelemetry
=
{
version
=
"0.14.1"
,
features
=
["opentelemetry-otlp"]
}
init-tracing-opentelemetry
=
{
version
=
"0.14.1"
,
features
=
["opentelemetry-otlp"]
}
[build-dependencies]
[build-dependencies]
...
...
router/src/main.rs
View file @
fbeb1c44
/// Text Generation Inference webserver entrypoint
use
axum
::
http
::
HeaderValue
;
use
axum
::
http
::
HeaderValue
;
use
clap
::
Parser
;
use
clap
::
Parser
;
use
hf_hub
::
api
::
tokio
::{
Api
,
ApiBuilder
,
ApiRepo
};
use
hf_hub
::{
Repo
,
RepoType
};
use
opentelemetry
::
sdk
::
propagation
::
TraceContextPropagator
;
use
opentelemetry
::
sdk
::
propagation
::
TraceContextPropagator
;
use
opentelemetry
::
sdk
::
trace
;
use
opentelemetry
::
sdk
::
trace
;
use
opentelemetry
::
sdk
::
trace
::
Sampler
;
use
opentelemetry
::
sdk
::
trace
::
Sampler
;
use
opentelemetry
::
sdk
::
Resource
;
use
opentelemetry
::
sdk
::
Resource
;
use
opentelemetry
::{
global
,
KeyValue
};
use
opentelemetry
::{
global
,
KeyValue
};
use
opentelemetry_otlp
::
WithExportConfig
;
use
opentelemetry_otlp
::
WithExportConfig
;
/// Text Generation Inference webserver entrypoint
use
std
::
fs
::
File
;
use
std
::
io
::
BufReader
;
use
std
::
net
::{
IpAddr
,
Ipv4Addr
,
SocketAddr
};
use
std
::
net
::{
IpAddr
,
Ipv4Addr
,
SocketAddr
};
use
std
::
path
::
Path
;
use
std
::
path
::
Path
;
use
std
::
time
::
Duration
;
use
text_generation_client
::{
ClientError
,
ShardedClient
};
use
text_generation_client
::{
ClientError
,
ShardedClient
};
use
text_generation_router
::{
server
,
HubModelInfo
};
use
text_generation_router
::{
server
,
HubModelInfo
};
use
thiserror
::
Error
;
use
thiserror
::
Error
;
use
tokenizers
::
{
FromPretrainedParameters
,
Tokenizer
}
;
use
tokenizers
::
Tokenizer
;
use
tower_http
::
cors
::
AllowOrigin
;
use
tower_http
::
cors
::
AllowOrigin
;
use
tracing_subscriber
::
layer
::
SubscriberExt
;
use
tracing_subscriber
::
layer
::
SubscriberExt
;
use
tracing_subscriber
::
util
::
SubscriberInitExt
;
use
tracing_subscriber
::
util
::
SubscriberInitExt
;
...
@@ -69,7 +72,8 @@ struct Args {
...
@@ -69,7 +72,8 @@ struct Args {
ngrok_edge
:
Option
<
String
>
,
ngrok_edge
:
Option
<
String
>
,
}
}
fn
main
()
->
Result
<
(),
RouterError
>
{
#[tokio::main]
async
fn
main
()
->
Result
<
(),
RouterError
>
{
// Get args
// Get args
let
args
=
Args
::
parse
();
let
args
=
Args
::
parse
();
// Pattern match configuration
// Pattern match configuration
...
@@ -98,6 +102,9 @@ fn main() -> Result<(), RouterError> {
...
@@ -98,6 +102,9 @@ fn main() -> Result<(), RouterError> {
ngrok_edge
,
ngrok_edge
,
}
=
args
;
}
=
args
;
// Launch Tokio runtime
init_logging
(
otlp_endpoint
,
json_output
);
// Validate args
// Validate args
if
max_input_length
>=
max_total_tokens
{
if
max_input_length
>=
max_total_tokens
{
return
Err
(
RouterError
::
ArgumentValidation
(
return
Err
(
RouterError
::
ArgumentValidation
(
...
@@ -141,53 +148,63 @@ fn main() -> Result<(), RouterError> {
...
@@ -141,53 +148,63 @@ fn main() -> Result<(), RouterError> {
// This will only be used to validate payloads
// This will only be used to validate payloads
let
local_path
=
Path
::
new
(
&
tokenizer_name
);
let
local_path
=
Path
::
new
(
&
tokenizer_name
);
let
local_model
=
local_path
.exists
()
&&
local_path
.is_dir
();
let
local_model
=
local_path
.exists
()
&&
local_path
.is_dir
();
let
tokenizer
=
if
local_model
{
let
(
tokenizer
,
model_info
)
=
if
local_model
{
// Get Model info
let
model_info
=
HubModelInfo
{
model_id
:
tokenizer_name
.clone
(),
sha
:
None
,
pipeline_tag
:
None
,
};
// Load local tokenizer
// Load local tokenizer
Tokenizer
::
from_file
(
local_path
.join
(
"tokenizer.json"
))
.ok
()
let
tokenizer
=
Tokenizer
::
from_file
(
local_path
.join
(
"tokenizer.json"
))
.ok
();
(
tokenizer
,
model_info
)
}
else
{
}
else
{
// Download and instantiate tokenizer
let
mut
builder
=
ApiBuilder
::
new
()
// We need to download it outside of the Tokio runtime
.with_progress
(
false
)
let
params
=
FromPretrainedParameters
{
.with_token
(
authorization_token
);
revision
:
revision
.clone
()
.unwrap_or
(
"main"
.to_string
()),
auth_token
:
authorization_token
.clone
(),
..
Default
::
default
()
};
Tokenizer
::
from_pretrained
(
tokenizer_name
.clone
(),
Some
(
params
))
.ok
()
};
// Launch Tokio runtime
if
let
Some
(
cache_dir
)
=
std
::
env
::
var
(
"HUGGINGFACE_HUB_CACHE"
)
.ok
()
{
tokio
::
runtime
::
Builder
::
new_multi_thread
()
builder
=
builder
.with_cache_dir
(
cache_dir
.into
());
.enable_all
()
}
.build
()
?
.block_on
(
async
{
init_logging
(
otlp_endpoint
,
json_output
);
if
tokenizer
.is_none
()
{
if
revision
.is_none
()
{
tracing
::
warn!
(
tracing
::
warn!
(
"`--revision` is not set"
);
"Could not find a fast tokenizer implementation for {tokenizer_name}"
tracing
::
warn!
(
"We strongly advise to set it to a known supported commit."
);
);
tracing
::
warn!
(
"Rust input length validation and truncation is disabled"
);
}
}
let
api
=
builder
.build
()
.unwrap
();
let
api_repo
=
api
.repo
(
Repo
::
with_revision
(
tokenizer_name
.clone
(),
RepoType
::
Model
,
revision
.clone
()
.unwrap_or
(
"main"
.to_string
()),
));
// Get Model info
// Get Model info
let
model_info
=
match
local_model
{
let
model_info
=
get_model_info
(
&
api_repo
)
.await
.unwrap_or_else
(||
{
true
=>
HubModelInfo
{
model_id
:
tokenizer_name
.clone
(),
sha
:
None
,
pipeline_tag
:
None
,
},
false
=>
get_model_info
(
&
tokenizer_name
,
revision
,
authorization_token
)
.await
.unwrap_or_else
(||
{
tracing
::
warn!
(
"Could not retrieve model info from the Hugging Face hub."
);
tracing
::
warn!
(
"Could not retrieve model info from the Hugging Face hub."
);
HubModelInfo
{
HubModelInfo
{
model_id
:
tokenizer_name
.to_string
(),
model_id
:
tokenizer_name
.to_string
(),
sha
:
None
,
sha
:
None
,
pipeline_tag
:
None
,
pipeline_tag
:
None
,
}
}
}),
});
let
tokenizer
=
match
api_repo
.get
(
"tokenizer.json"
)
.await
{
Ok
(
tokenizer_filename
)
=>
Tokenizer
::
from_file
(
tokenizer_filename
)
.ok
(),
Err
(
_
)
=>
get_base_tokenizer
(
&
api
,
&
api_repo
)
.await
,
};
(
tokenizer
,
model_info
)
};
};
if
tokenizer
.is_none
()
{
tracing
::
warn!
(
"Could not find a fast tokenizer implementation for {tokenizer_name}"
);
tracing
::
warn!
(
"Rust input length validation and truncation is disabled"
);
}
// if pipeline-tag == text-generation we default to return_full_text = true
// if pipeline-tag == text-generation we default to return_full_text = true
let
compat_return_full_text
=
match
&
model_info
.pipeline_tag
{
let
compat_return_full_text
=
match
&
model_info
.pipeline_tag
{
None
=>
{
None
=>
{
...
@@ -212,15 +229,18 @@ fn main() -> Result<(), RouterError> {
...
@@ -212,15 +229,18 @@ fn main() -> Result<(), RouterError> {
// Warmup model
// Warmup model
tracing
::
info!
(
"Warming up model"
);
tracing
::
info!
(
"Warming up model"
);
let
max_supported_batch_total_tokens
=
match
sharded_client
let
max_supported_batch_total_tokens
=
match
sharded_client
.warmup
(
max_input_length
as
u32
,
max_batch_prefill_tokens
,
max_total_tokens
as
u32
)
.warmup
(
max_input_length
as
u32
,
max_batch_prefill_tokens
,
max_total_tokens
as
u32
,
)
.await
.await
.map_err
(
RouterError
::
Warmup
)
?
.map_err
(
RouterError
::
Warmup
)
?
{
{
// Older models do not support automatic max-batch-total-tokens
// Older models do not support automatic max-batch-total-tokens
None
=>
{
None
=>
{
let
max_batch_total_tokens
=
max_batch_total_tokens
.unwrap_or
(
let
max_batch_total_tokens
=
max_batch_total_tokens
16000
.max
((
max_total_tokens
as
u32
)
.max
(
max_batch_prefill_tokens
)),
.unwrap_or
(
16000
.max
((
max_total_tokens
as
u32
)
.max
(
max_batch_prefill_tokens
)));
);
tracing
::
warn!
(
"Model does not support automatic max batch total tokens"
);
tracing
::
warn!
(
"Model does not support automatic max batch total tokens"
);
max_batch_total_tokens
max_batch_total_tokens
}
}
...
@@ -280,7 +300,6 @@ fn main() -> Result<(), RouterError> {
...
@@ -280,7 +300,6 @@ fn main() -> Result<(), RouterError> {
)
)
.await
?
;
.await
?
;
Ok
(())
Ok
(())
})
}
}
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
...
@@ -339,30 +358,8 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
...
@@ -339,30 +358,8 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
}
}
/// get model info from the Huggingface Hub
/// get model info from the Huggingface Hub
pub
async
fn
get_model_info
(
pub
async
fn
get_model_info
(
api
:
&
ApiRepo
)
->
Option
<
HubModelInfo
>
{
model_id
:
&
str
,
let
response
=
api
.info_request
()
.send
()
.await
.ok
()
?
;
revision
:
Option
<
String
>
,
token
:
Option
<
String
>
,
)
->
Option
<
HubModelInfo
>
{
let
revision
=
match
revision
{
None
=>
{
tracing
::
warn!
(
"`--revision` is not set"
);
tracing
::
warn!
(
"We strongly advise to set it to a known supported commit."
);
"main"
.to_string
()
}
Some
(
revision
)
=>
revision
,
};
let
client
=
reqwest
::
Client
::
new
();
// Poor man's urlencode
let
revision
=
revision
.replace
(
'/'
,
"%2F"
);
let
url
=
format!
(
"https://huggingface.co/api/models/{model_id}/revision/{revision}"
);
let
mut
builder
=
client
.get
(
url
)
.timeout
(
Duration
::
from_secs
(
5
));
if
let
Some
(
token
)
=
token
{
builder
=
builder
.bearer_auth
(
token
);
}
let
response
=
builder
.send
()
.await
.ok
()
?
;
if
response
.status
()
.is_success
()
{
if
response
.status
()
.is_success
()
{
let
hub_model_info
:
HubModelInfo
=
let
hub_model_info
:
HubModelInfo
=
...
@@ -379,6 +376,31 @@ pub async fn get_model_info(
...
@@ -379,6 +376,31 @@ pub async fn get_model_info(
}
}
}
}
/// get base tokenizer
pub
async
fn
get_base_tokenizer
(
api
:
&
Api
,
api_repo
:
&
ApiRepo
)
->
Option
<
Tokenizer
>
{
let
config_filename
=
api_repo
.get
(
"config.json"
)
.await
.ok
()
?
;
// Open the file in read-only mode with buffer.
let
file
=
File
::
open
(
config_filename
)
.ok
()
?
;
let
reader
=
BufReader
::
new
(
file
);
// Read the JSON contents of the file as an instance of `User`.
let
config
:
serde_json
::
Value
=
serde_json
::
from_reader
(
reader
)
.ok
()
?
;
if
let
Some
(
serde_json
::
Value
::
String
(
base_model_id
))
=
config
.get
(
"base_model_name_or_path"
)
{
let
api_base_repo
=
api
.repo
(
Repo
::
with_revision
(
base_model_id
.to_string
(),
RepoType
::
Model
,
"main"
.to_string
(),
));
let
tokenizer_filename
=
api_base_repo
.get
(
"tokenizer.json"
)
.await
.ok
()
?
;
Tokenizer
::
from_file
(
tokenizer_filename
)
.ok
()
}
else
{
None
}
}
#[derive(Debug,
Error)]
#[derive(Debug,
Error)]
enum
RouterError
{
enum
RouterError
{
#[error(
"Argument validation error: {0}"
)]
#[error(
"Argument validation error: {0}"
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment