OpenDAS / text-generation-inference / Commits / 4e685d90

Unverified commit 4e685d90, authored Feb 28, 2023 by OlivierDehaene, committed via GitHub on Feb 28, 2023.

feat(router): ask hf.co for pipelinetag to decide on compat_return_full_text (#89)

parent 21340f24

Showing 5 changed files with 87 additions and 29 deletions.
Cargo.lock            +1   -0
router/Cargo.toml     +1   -0
router/src/lib.rs     +6   -6
router/src/main.rs    +23  -1
router/src/server.rs  +56  -22
Cargo.lock

@@ -2268,6 +2268,7 @@ dependencies = [
  "opentelemetry-otlp",
  "parking_lot",
  "rand",
+ "reqwest",
  "serde",
  "serde_json",
  "text-generation-client",
router/Cargo.toml

@@ -26,6 +26,7 @@ opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
 parking_lot = "0.12.1"
 rand = "0.8.5"
+reqwest = { version = "0.11.14", features = [] }
 serde = "1.0.152"
 serde_json = "1.0.93"
 thiserror = "1.0.38"
router/src/lib.rs

@@ -40,13 +40,16 @@ pub(crate) struct GenerateParameters {
         example = 0.95
     )]
     pub top_p: Option<f32>,
-    #[serde(default = "default_do_sample")]
+    #[serde(default)]
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
     #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
+    #[serde(default)]
+    #[schema(default = "None", example = false)]
+    pub return_full_text: Option<bool>,
     #[serde(default)]
     #[schema(inline, max_items = 4, example = json!(["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]

@@ -56,10 +59,6 @@ pub(crate) struct GenerateParameters {
     pub seed: Option<u64>,
 }
 
-fn default_do_sample() -> bool {
-    false
-}
-
 fn default_max_new_tokens() -> u32 {
     20
 }

@@ -70,8 +69,9 @@ fn default_parameters() -> GenerateParameters {
         repetition_penalty: None,
         top_k: None,
         top_p: None,
-        do_sample: default_do_sample(),
+        do_sample: false,
         max_new_tokens: default_max_new_tokens(),
+        return_full_text: None,
         stop: vec![],
         details: false,
         seed: None,
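Because `return_full_text` is declared as `Option<bool>` with `#[serde(default)]`, clients that omit the field are unaffected and the compat route can fill in a default later. A minimal sketch of a `/generate` request body that opts in explicitly; the `inputs`/`parameters` envelope, the prompt text, and the token count are illustrative assumptions, not taken from this diff:

    use serde_json::{json, Value};

    // Hypothetical client-side payload; only `return_full_text` is new in this commit.
    fn example_generate_payload() -> Value {
        json!({
            "inputs": "My name is Olivier and I",   // illustrative prompt
            "parameters": {
                "max_new_tokens": 20,               // pre-existing parameter
                "return_full_text": true            // new optional flag
            }
        })
    }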
router/src/main.rs

@@ -87,7 +87,7 @@ fn main() -> Result<(), std::io::Error> {
     // This will only be used to validate payloads
     //
     // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name, None).unwrap();
+    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
 
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()

@@ -97,6 +97,27 @@ fn main() -> Result<(), std::io::Error> {
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 
+            // Get pipeline tag
+            let model_info = reqwest::get(format!(
+                "https://huggingface.co/api/models/{tokenizer_name}"
+            ))
+            .await
+            .expect("Could not connect to hf.co")
+            .text()
+            .await
+            .expect("error when retrieving model info from hf.co");
+            let model_info: serde_json::Value =
+                serde_json::from_str(&model_info).expect("unable to parse model info");
+
+            // if pipeline-tag == text-generation we default to return_full_text = true
+            let compat_return_full_text = match model_info.get("pipeline_tag") {
+                None => {
+                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
+                    false
+                }
+                Some(pipeline_tag) => pipeline_tag.as_str() == Some("text-generation"),
+            };
+
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await

@@ -113,6 +134,7 @@ fn main() -> Result<(), std::io::Error> {
             // Run server
             server::run(
+                compat_return_full_text,
                 max_concurrent_requests,
                 max_stop_sequences,
                 max_input_length,
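The hunk above queries the Hub's model-info endpoint (`https://huggingface.co/api/models/{tokenizer_name}`) at startup and panics via `expect` if the request or the JSON parse fails. A rough standalone sketch of the same decision rule, with failures folded into a `false` default; the function name and error handling are illustrative and not part of the router:

    // Sketch only (assumed helper): resolve the compat default from the Hub's
    // `pipeline_tag` field, treating any network or parse failure like a missing tag.
    async fn resolve_compat_return_full_text(model_id: &str) -> bool {
        let url = format!("https://huggingface.co/api/models/{model_id}");
        let info = async {
            let body = reqwest::get(url).await.ok()?.text().await.ok()?;
            serde_json::from_str::<serde_json::Value>(&body).ok()
        }
        .await;

        match info.as_ref().and_then(|i| i.get("pipeline_tag")) {
            // Only text-generation models get the prompt echoed back by default.
            Some(tag) => tag.as_str() == Some("text-generation"),
            None => false,
        }
    }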
router/src/server.rs

@@ -29,11 +29,18 @@ use utoipa_swagger_ui::SwaggerUi;
 /// Compatibility route with api-inference and AzureML
 #[instrument(skip(infer))]
 async fn compat_generate(
+    default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
     req: Json<CompatGenerateRequest>,
 ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
+    let mut req = req.0;
+
+    // default return_full_text given the pipeline_tag
+    if req.parameters.return_full_text.is_none() {
+        req.parameters.return_full_text = Some(default_return_full_text.0)
+    }
+
     // switch on stream
-    let req = req.0;
     if req.stream {
         Ok(generate_stream(infer, Json(req.into()))
             .await

@@ -63,6 +70,7 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
                 top_p: None,
                 do_sample: false,
                 max_new_tokens: 1,
+                return_full_text: None,
                 stop: Vec::new(),
                 details: false,
                 seed: None,

@@ -81,13 +89,13 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
     responses(
         (status = 200, description = "Generated Text", body = GenerateResponse),
         (status = 424, description = "Generation Error", body = ErrorResponse,
         example = json!({"error": "Request failed during generation"})),
         (status = 429, description = "Model is overloaded", body = ErrorResponse,
         example = json!({"error": "Model is overloaded"})),
         (status = 422, description = "Input validation error", body = ErrorResponse,
         example = json!({"error": "Input validation error"})),
         (status = 500, description = "Incomplete generation", body = ErrorResponse,
         example = json!({"error": "Incomplete generation"})),
     )
 )]
 #[instrument(

@@ -108,8 +116,14 @@ async fn generate(
     let span = tracing::Span::current();
     let start_time = Instant::now();
 
-    // Inference
+    let mut add_prompt = None;
+    if req.0.parameters.return_full_text.unwrap_or(false) {
+        add_prompt = Some(req.0.inputs.clone());
+    }
+
     let details = req.0.parameters.details;
+
+    // Inference
     let response = infer.generate(req.0).await?;
 
     // Token details

@@ -176,8 +190,13 @@ async fn generate(
     );
 
     // Send response
+    let mut output_text = response.generated_text.text;
+    if let Some(prompt) = add_prompt {
+        output_text = prompt + &output_text;
+    }
+
     let response = GenerateResponse {
-        generated_text: response.generated_text.text,
+        generated_text: output_text,
         details,
     };
     Ok((headers, Json(response)))

@@ -191,19 +210,19 @@ async fn generate(
     request_body = GenerateRequest,
     responses(
         (status = 200, description = "Generated Text", body = StreamResponse,
         content_type = "text/event-stream"),
         (status = 424, description = "Generation Error", body = ErrorResponse,
         example = json!({"error": "Request failed during generation"}),
         content_type = "text/event-stream"),
         (status = 429, description = "Model is overloaded", body = ErrorResponse,
         example = json!({"error": "Model is overloaded"}),
         content_type = "text/event-stream"),
         (status = 422, description = "Input validation error", body = ErrorResponse,
         example = json!({"error": "Input validation error"}),
         content_type = "text/event-stream"),
         (status = 500, description = "Incomplete generation", body = ErrorResponse,
         example = json!({"error": "Incomplete generation"}),
         content_type = "text/event-stream"),
     )
 )]
 #[instrument(

@@ -228,6 +247,11 @@ async fn generate_stream(
     // Inference
     let mut end_reached = false;
     let mut error = false;
+
+    let mut add_prompt = None;
+    if req.0.parameters.return_full_text.unwrap_or(false) {
+        add_prompt = Some(req.0.inputs.clone());
+    }
     let details = req.0.parameters.details;
 
     match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {

@@ -294,20 +318,28 @@ async fn generate_stream(
                             // StreamResponse
                             end_reached = true;
 
+                            let mut output_text = generated_text.text;
+                            if let Some(prompt) = add_prompt {
+                                output_text = prompt + &output_text;
+                            }
+
                             let stream_token = StreamResponse {
                                 token,
-                                generated_text: Some(generated_text.text),
+                                generated_text: Some(output_text),
                                 details
                             };
 
-                            yield Ok(Event::default().json_data(stream_token).unwrap())
+                            yield Ok(Event::default().json_data(stream_token).unwrap());
+                            break;
                         }
                     }
                 }
                 // yield error
                 Err(err) => {
                     error = true;
-                    yield Ok(Event::from(err))
+                    yield Ok(Event::from(err));
+                    break;
                 }
             }
         }

@@ -315,7 +347,7 @@ async fn generate_stream(
         // yield error
         Err(err) => {
             error = true;
-            yield Ok(Event::from(err))
+            yield Ok(Event::from(err));
         }
     }
     // Check if generation reached the end

@@ -324,7 +356,7 @@ async fn generate_stream(
         let err = InferError::IncompleteGeneration;
         metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
         tracing::error!("{err}");
-        yield Ok(Event::from(err))
+        yield Ok(Event::from(err));
     }
 };

@@ -345,6 +377,7 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
+    compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_stop_sequences: usize,
     max_input_length: usize,

@@ -429,8 +462,9 @@ pub async fn run(
         .route("/generate_stream", post(generate_stream))
         .route("/", get(health))
         .route("/health", get(health))
-        .layer(Extension(infer))
         .route("/metrics", get(metrics))
+        .layer(Extension(compat_return_full_text))
+        .layer(Extension(infer))
         .layer(Extension(prom_handle))
         .layer(opentelemetry_tracing_layer())
         .layer(cors_layer);
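Taken together, the `generate` and `generate_stream` changes apply the same rule: when `return_full_text` resolves to true, the original prompt is prepended to the generated text before the response is serialized. A tiny illustrative helper, not present in the codebase, that captures that rule along with a usage check:

    // Assumed illustration of the prompt-prepend rule added in the handlers above.
    fn assemble_output(prompt: &str, generated: String, return_full_text: bool) -> String {
        if return_full_text {
            // Mirrors `output_text = prompt + &output_text` in generate()/generate_stream().
            format!("{prompt}{generated}")
        } else {
            generated
        }
    }

    #[cfg(test)]
    mod tests {
        use super::assemble_output;

        #[test]
        fn prepends_prompt_only_when_requested() {
            assert_eq!(
                assemble_output("Hello", ", world".to_string(), true),
                "Hello, world"
            );
            assert_eq!(
                assemble_output("Hello", ", world".to_string(), false),
                ", world"
            );
        }
    }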