/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
    Infer, PrefillToken, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    //       be a bit too slow for a health check.
    //       What we should do instead is check if the gRPC channels are still healthy.

    // Send a small inference request
    infer
        .generate(GenerateRequest {
            inputs: "liveness".to_string(),
            parameters: GenerateParameters {
                temperature: None,
                repetition_penalty: None,
                top_k: None,
                top_p: None,
                do_sample: false,
                max_new_tokens: 1,
                stop: Vec::new(),
                details: false,
                seed: None,
            },
        })
        .await?;
    Ok(())
}

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = GenerateResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    // Inference
    let details = req.0.parameters.details;
    let response = infer.generate(req.0).await?;

    // Token details
    let details = match details {
        true => Some(Details {
            finish_reason: FinishReason::from(response.generated_text.finish_reason),
            generated_tokens: response.generated_text.generated_tokens,
            prefill: Some(response.prefill),
            tokens: Some(response.tokens),
            seed: response.generated_text.seed,
        }),
        false => None,
    };

    // Timings
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Headers
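    // The values below are millisecond counts rendered as ASCII digits, so parsing
    // them into header values cannot fail and the `unwrap()` calls are safe.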
    let mut headers = HeaderMap::new();
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Tracing metadata
    span.record("total_time", format!("{total_time:?}"));
    span.record("validation_time", format!("{validation_time:?}"));
    span.record("queue_time", format!("{queue_time:?}"));
    span.record("inference_time", format!("{inference_time:?}"));
    span.record("time_per_token", format!("{time_per_token:?}"));
    span.record("seed", format!("{:?}", response.generated_text.seed));
    tracing::info!("Output: {}", response.generated_text.text);

    // Metrics
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_duration", total_time);
    metrics::histogram!("tgi_request_validation_duration", validation_time);
    metrics::histogram!("tgi_request_queue_duration", queue_time);
    metrics::histogram!("tgi_request_inference_duration", inference_time);
    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
    metrics::histogram!(
        "tgi_request_generated_tokens",
        response.generated_text.generated_tokens as f64
    );

    // Send response
    let response = GenerateResponse {
        generated_text: response.generated_text.text,
        details,
    };
    Ok((headers, Json(response)))
}

/// Generate a stream of tokens using Server-Sent Events
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate_stream",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = StreamResponse,
            content_type="text/event-stream"),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"}),
            content_type="text/event-stream"),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"}),
            content_type="text/event-stream"),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"}),
            content_type="text/event-stream"),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"}),
            content_type="text/event-stream"),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate_stream(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let stream = async_stream::stream! {
        // Inference
        let mut end_reached = false;
        let mut error = false;
        let details = req.0.parameters.details;

        match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
            Ok(mut response_stream) => {
                // Server-Sent Event stream
                while let Some(response) = response_stream.next().await {
                    match response {
                        Ok(response) => {
                            match response {
                                // Prefill is ignored
                                InferStreamResponse::Prefill(_) => {}
                                // Yield event for every new token
                                InferStreamResponse::Token(token) => {
                                    // StreamResponse
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: None,
                                        details: None,
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                                // Yield event for last token and compute timings
                                InferStreamResponse::End {
                                    token,
                                    generated_text,
                                    start,
                                    queued,
                                } => {
                                    // Token details
                                    let details = match details {
                                        true => Some(StreamDetails {
                                            finish_reason: FinishReason::from(generated_text.finish_reason),
                                            generated_tokens: generated_text.generated_tokens,
                                            seed: generated_text.seed,
                                        }),
                                        false => None,
                                    };

                                    // Timings
                                    let total_time = start_time.elapsed();
                                    let validation_time = queued - start_time;
                                    let queue_time = start - queued;
                                    let inference_time = Instant::now() - start;
                                    let time_per_token = inference_time / generated_text.generated_tokens;

                                    // Tracing metadata
                                    span.record("total_time", format!("{total_time:?}"));
                                    span.record("validation_time", format!("{validation_time:?}"));
                                    span.record("queue_time", format!("{queue_time:?}"));
                                    span.record("inference_time", format!("{inference_time:?}"));
                                    span.record("time_per_token", format!("{time_per_token:?}"));
                                    span.record("seed", format!("{:?}", generated_text.seed));
                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                    // Metrics
                                    metrics::increment_counter!("tgi_request_success");
                                    metrics::histogram!("tgi_request_duration", total_time);
                                    metrics::histogram!("tgi_request_validation_duration", validation_time);
                                    metrics::histogram!("tgi_request_queue_duration", queue_time);
                                    metrics::histogram!("tgi_request_inference_duration", inference_time);
                                    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
                                    metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                    // StreamResponse
                                    end_reached = true;
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: Some(generated_text.text),
                                        details
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                            }
                        }
                        // yield error
                        Err(err) => {
                            error = true;
                            yield Ok(Event::from(err))
                        }
                    }
                }
            },
            // yield error
            Err(err) => {
                error = true;
                yield Ok(Event::from(err))
            }
        }
        // Check if generation reached the end
        // Skip if we already sent an error
        if !end_reached && !error {
            let err = InferError::IncompleteGeneration;
            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
            tracing::error!("{err}");
            yield Ok(Event::from(err))
        }
    };

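    // The default keep-alive periodically emits SSE comment frames so that proxies and
    // clients do not drop the connection while the model is busy between tokens.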
    Sse::new(stream).keep_alive(KeepAlive::default())
}

/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/metrics",
    responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    max_concurrent_requests: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
    allow_origin: Option<AllowOrigin>,
) {
    // OpenAPI documentation
    #[derive(OpenApi)]
    #[openapi(
        paths(
            generate,
            generate_stream,
            metrics,
        ),
        components(
            schemas(
                GenerateRequest,
                GenerateParameters,
                PrefillToken,
                Token,
                GenerateResponse,
                Details,
                FinishReason,
                StreamResponse,
                StreamDetails,
                ErrorResponse,
            )
        ),
        tags(
            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
        ),
        info(
            title = "Text Generation Inference",
            license(
                name = "Apache 2.0",
                url = "https://www.apache.org/licenses/LICENSE-2.0"
            )
        )
    )]
    struct ApiDoc;

    // Create state
    let validation = Validation::new(
        validation_workers,
        tokenizer,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
    );
    let infer = Infer::new(
        client,
        validation,
        max_batch_size,
        max_waiting_tokens,
        max_concurrent_requests,
    );

    // Prometheus handler
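    // Installing the recorder registers it globally; the returned handle is shared with
    // the `/metrics` route below to render the scrape payload.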
    let builder = PrometheusBuilder::new();
    let prom_handle = builder
        .install_recorder()
        .expect("failed to install metrics recorder");

    // CORS layer
    let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE])
        .allow_origin(allow_origin);

    // Create router
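    // Note: each `layer` call only wraps the routes registered before it, so the Infer
    // extension is not attached to `/metrics`, while the Prometheus handle, tracing and
    // CORS layers apply to every route above them.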
    let app = Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
        .route("/", post(generate))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
        .layer(Extension(infer))
        .route("/metrics", get(metrics))
        .layer(Extension(prom_handle))
        .layer(opentelemetry_tracing_layer())
        .layer(cors_layer);

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}

/// Shutdown signal handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    opentelemetry::global::shutdown_tracer_provider();
}

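/// Map the protobuf `FinishReason` received from the shards to the public API enum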
impl From<i32> for FinishReason {
    fn from(finish_reason: i32) -> Self {
        let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
        match finish_reason {
            text_generation_client::FinishReason::Length => FinishReason::Length,
            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}

/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
        };

        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
            }),
        )
    }
}

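/// Convert errors into SSE error events for the streaming route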
impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
            })
            .unwrap()
    }
}