Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
1da642bd
Commit
1da642bd
authored
Jul 21, 2023
by
OlivierDehaene
Browse files
feat(server): add local prom and health routes if running w/ ngrok
parent
15b3e9ff
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
142 additions
and
127 deletions
+142
-127
router/src/server.rs
router/src/server.rs
+142
-127
No files found.
router/src/server.rs
View file @
1da642bd
...
@@ -32,25 +32,25 @@ use utoipa_swagger_ui::SwaggerUi;
...
@@ -32,25 +32,25 @@ use utoipa_swagger_ui::SwaggerUi;
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
/// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path(
#[utoipa::path(
post,
post,
tag
=
"Text Generation Inference"
,
tag
=
"Text Generation Inference"
,
path
=
"/"
,
path
=
"/"
,
request_body
=
CompatGenerateRequest,
request_body
=
CompatGenerateRequest,
responses(
responses(
(status
=
200
,
description
=
"Generated Text"
,
(status
=
200
,
description
=
"Generated Text"
,
content(
content(
(
"application/json"
=
GenerateResponse),
(
"application/json"
=
GenerateResponse),
(
"text/event-stream"
=
StreamResponse),
(
"text/event-stream"
=
StreamResponse),
)),
)),
(status
=
424
,
description
=
"Generation Error"
,
body
=
ErrorResponse,
(status
=
424
,
description
=
"Generation Error"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Request failed during generation"
}
)),
example
=
json
!
(
{
"error"
:
"Request failed during generation"
}
)),
(status
=
429
,
description
=
"Model is overloaded"
,
body
=
ErrorResponse,
(status
=
429
,
description
=
"Model is overloaded"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Model is overloaded"
}
)),
example
=
json
!
(
{
"error"
:
"Model is overloaded"
}
)),
(status
=
422
,
description
=
"Input validation error"
,
body
=
ErrorResponse,
(status
=
422
,
description
=
"Input validation error"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Input validation error"
}
)),
example
=
json
!
(
{
"error"
:
"Input validation error"
}
)),
(status
=
500
,
description
=
"Incomplete generation"
,
body
=
ErrorResponse,
(status
=
500
,
description
=
"Incomplete generation"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Incomplete generation"
}
)),
example
=
json
!
(
{
"error"
:
"Incomplete generation"
}
)),
)
)
)]
)]
#[instrument(skip(infer,
req))]
#[instrument(skip(infer,
req))]
async
fn
compat_generate
(
async
fn
compat_generate
(
...
@@ -79,10 +79,10 @@ async fn compat_generate(
...
@@ -79,10 +79,10 @@ async fn compat_generate(
/// Text Generation Inference endpoint info
/// Text Generation Inference endpoint info
#[utoipa::path(
#[utoipa::path(
get,
get,
tag
=
"Text Generation Inference"
,
tag
=
"Text Generation Inference"
,
path
=
"/info"
,
path
=
"/info"
,
responses((status
=
200
,
description
=
"Served model info"
,
body
=
Info))
responses((status
=
200
,
description
=
"Served model info"
,
body
=
Info))
)]
)]
#[instrument]
#[instrument]
async
fn
get_model_info
(
info
:
Extension
<
Info
>
)
->
Json
<
Info
>
{
async
fn
get_model_info
(
info
:
Extension
<
Info
>
)
->
Json
<
Info
>
{
...
@@ -90,14 +90,14 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
...
@@ -90,14 +90,14 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
}
}
#[utoipa::path(
#[utoipa::path(
get,
get,
tag
=
"Text Generation Inference"
,
tag
=
"Text Generation Inference"
,
path
=
"/health"
,
path
=
"/health"
,
responses(
responses(
(status
=
200
,
description
=
"Everything is working fine"
),
(status
=
200
,
description
=
"Everything is working fine"
),
(status
=
503
,
description
=
"Text generation inference is down"
,
body
=
ErrorResponse,
(status
=
503
,
description
=
"Text generation inference is down"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"unhealthy"
,
"error_type"
:
"healthcheck"
}
)),
example
=
json
!
(
{
"error"
:
"unhealthy"
,
"error_type"
:
"healthcheck"
}
)),
)
)
)]
)]
#[instrument(skip(health))]
#[instrument(skip(health))]
/// Health check method
/// Health check method
...
@@ -116,33 +116,33 @@ async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<E
...
@@ -116,33 +116,33 @@ async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<E
/// Generate tokens
/// Generate tokens
#[utoipa::path(
#[utoipa::path(
post,
post,
tag
=
"Text Generation Inference"
,
tag
=
"Text Generation Inference"
,
path
=
"/generate"
,
path
=
"/generate"
,
request_body
=
GenerateRequest,
request_body
=
GenerateRequest,
responses(
responses(
(status
=
200
,
description
=
"Generated Text"
,
body
=
GenerateResponse),
(status
=
200
,
description
=
"Generated Text"
,
body
=
GenerateResponse),
(status
=
424
,
description
=
"Generation Error"
,
body
=
ErrorResponse,
(status
=
424
,
description
=
"Generation Error"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Request failed during generation"
}
)),
example
=
json
!
(
{
"error"
:
"Request failed during generation"
}
)),
(status
=
429
,
description
=
"Model is overloaded"
,
body
=
ErrorResponse,
(status
=
429
,
description
=
"Model is overloaded"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Model is overloaded"
}
)),
example
=
json
!
(
{
"error"
:
"Model is overloaded"
}
)),
(status
=
422
,
description
=
"Input validation error"
,
body
=
ErrorResponse,
(status
=
422
,
description
=
"Input validation error"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Input validation error"
}
)),
example
=
json
!
(
{
"error"
:
"Input validation error"
}
)),
(status
=
500
,
description
=
"Incomplete generation"
,
body
=
ErrorResponse,
(status
=
500
,
description
=
"Incomplete generation"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Incomplete generation"
}
)),
example
=
json
!
(
{
"error"
:
"Incomplete generation"
}
)),
)
)
)]
)]
#[instrument(
#[instrument(
skip_all,
skip_all,
fields(
fields(
parameters
=
?
req
.
0
.
parameters,
parameters
=
?
req
.
0
.
parameters,
total_time,
total_time,
validation_time,
validation_time,
queue_time,
queue_time,
inference_time,
inference_time,
time_per_token,
time_per_token,
seed,
seed,
)
)
)]
)]
async
fn
generate
(
async
fn
generate
(
infer
:
Extension
<
Infer
>
,
infer
:
Extension
<
Infer
>
,
...
@@ -297,38 +297,38 @@ async fn generate(
...
@@ -297,38 +297,38 @@ async fn generate(
/// Generate a stream of token using Server-Sent Events
/// Generate a stream of token using Server-Sent Events
#[utoipa::path(
#[utoipa::path(
post,
post,
tag
=
"Text Generation Inference"
,
tag
=
"Text Generation Inference"
,
path
=
"/generate_stream"
,
path
=
"/generate_stream"
,
request_body
=
GenerateRequest,
request_body
=
GenerateRequest,
responses(
responses(
(status
=
200
,
description
=
"Generated Text"
,
body
=
StreamResponse,
(status
=
200
,
description
=
"Generated Text"
,
body
=
StreamResponse,
content_type
=
"text/event-stream"
),
content_type
=
"text/event-stream"
),
(status
=
424
,
description
=
"Generation Error"
,
body
=
ErrorResponse,
(status
=
424
,
description
=
"Generation Error"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Request failed during generation"
}
),
example
=
json
!
(
{
"error"
:
"Request failed during generation"
}
),
content_type
=
"text/event-stream"
),
content_type
=
"text/event-stream"
),
(status
=
429
,
description
=
"Model is overloaded"
,
body
=
ErrorResponse,
(status
=
429
,
description
=
"Model is overloaded"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Model is overloaded"
}
),
example
=
json
!
(
{
"error"
:
"Model is overloaded"
}
),
content_type
=
"text/event-stream"
),
content_type
=
"text/event-stream"
),
(status
=
422
,
description
=
"Input validation error"
,
body
=
ErrorResponse,
(status
=
422
,
description
=
"Input validation error"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Input validation error"
}
),
example
=
json
!
(
{
"error"
:
"Input validation error"
}
),
content_type
=
"text/event-stream"
),
content_type
=
"text/event-stream"
),
(status
=
500
,
description
=
"Incomplete generation"
,
body
=
ErrorResponse,
(status
=
500
,
description
=
"Incomplete generation"
,
body
=
ErrorResponse,
example
=
json
!
(
{
"error"
:
"Incomplete generation"
}
),
example
=
json
!
(
{
"error"
:
"Incomplete generation"
}
),
content_type
=
"text/event-stream"
),
content_type
=
"text/event-stream"
),
)
)
)]
)]
#[instrument(
#[instrument(
skip_all,
skip_all,
fields(
fields(
parameters
=
?
req
.
0
.
parameters,
parameters
=
?
req
.
0
.
parameters,
total_time,
total_time,
validation_time,
validation_time,
queue_time,
queue_time,
inference_time,
inference_time,
time_per_token,
time_per_token,
seed,
seed,
)
)
)]
)]
async
fn
generate_stream
(
async
fn
generate_stream
(
infer
:
Extension
<
Infer
>
,
infer
:
Extension
<
Infer
>
,
...
@@ -493,10 +493,10 @@ async fn generate_stream(
...
@@ -493,10 +493,10 @@ async fn generate_stream(
/// Prometheus metrics scrape endpoint
/// Prometheus metrics scrape endpoint
#[utoipa::path(
#[utoipa::path(
get,
get,
tag
=
"Text Generation Inference"
,
tag
=
"Text Generation Inference"
,
path
=
"/metrics"
,
path
=
"/metrics"
,
responses((status
=
200
,
description
=
"Prometheus Metrics"
,
body
=
String))
responses((status
=
200
,
description
=
"Prometheus Metrics"
,
body
=
String))
)]
)]
async
fn
metrics
(
prom_handle
:
Extension
<
PrometheusHandle
>
)
->
String
{
async
fn
metrics
(
prom_handle
:
Extension
<
PrometheusHandle
>
)
->
String
{
prom_handle
.render
()
prom_handle
.render
()
...
@@ -683,10 +683,10 @@ pub async fn run(
...
@@ -683,10 +683,10 @@ pub async fn run(
// Prometheus metrics route
// Prometheus metrics route
.route
(
"/metrics"
,
get
(
metrics
))
.route
(
"/metrics"
,
get
(
metrics
))
.layer
(
Extension
(
info
))
.layer
(
Extension
(
info
))
.layer
(
Extension
(
health_ext
))
.layer
(
Extension
(
health_ext
.clone
()
))
.layer
(
Extension
(
compat_return_full_text
))
.layer
(
Extension
(
compat_return_full_text
))
.layer
(
Extension
(
infer
))
.layer
(
Extension
(
infer
))
.layer
(
Extension
(
prom_handle
))
.layer
(
Extension
(
prom_handle
.clone
()
))
.layer
(
opentelemetry_tracing_layer
())
.layer
(
opentelemetry_tracing_layer
())
.layer
(
cors_layer
);
.layer
(
cors_layer
);
...
@@ -712,6 +712,21 @@ pub async fn run(
...
@@ -712,6 +712,21 @@ pub async fn run(
let
listener
=
tunnel
.listen
()
.await
.unwrap
();
let
listener
=
tunnel
.listen
()
.await
.unwrap
();
// Run prom metrics and health locally too
tokio
::
spawn
(
axum
::
Server
::
bind
(
&
addr
)
.serve
(
Router
::
new
()
.route
(
"/health"
,
get
(
health
))
.route
(
"/metrics"
,
get
(
metrics
))
.layer
(
Extension
(
health_ext
))
.layer
(
Extension
(
prom_handle
))
.into_make_service
(),
)
//Wait until all requests are finished to shut down
.with_graceful_shutdown
(
shutdown_signal
()),
);
// Run server
// Run server
axum
::
Server
::
builder
(
listener
)
axum
::
Server
::
builder
(
listener
)
.serve
(
app
.into_make_service
())
.serve
(
app
.into_make_service
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment