/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
    Infer, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    //       be a bit too slow for a health check.
    //       What we should do instead is check if the gRPC channels are still healthy.

    // Send a small inference request
    infer
        .generate(GenerateRequest {
            inputs: "liveness".to_string(),
            parameters: GenerateParameters {
                temperature: None,
                repetition_penalty: None,
                top_k: None,
                top_p: None,
                do_sample: false,
                max_new_tokens: 1,
                stop: Vec::new(),
                details: false,
                seed: None,
            },
        })
        .await?;
    Ok(())
}

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = GenerateResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    // Inference
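    // Read the `details` flag before `req.0` is moved into `infer.generate`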
    let details = req.0.parameters.details;
    let response = infer.generate(req.0).await?;

    // Token details
    let details = match details {
        true => Some(Details {
            finish_reason: FinishReason::from(response.generated_text.finish_reason),
            generated_tokens: response.generated_text.generated_tokens,
            prefill: Some(response.prefill),
            tokens: Some(response.tokens),
            seed: response.generated_text.seed,
        }),
        false => None,
    };

    // Timings
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Headers
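    // Expose the timing breakdown to the client as custom `x-*` headers, in milliseconds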
    let mut headers = HeaderMap::new();
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Tracing metadata
    span.record("total_time", format!("{total_time:?}"));
    span.record("validation_time", format!("{validation_time:?}"));
    span.record("queue_time", format!("{queue_time:?}"));
    span.record("inference_time", format!("{inference_time:?}"));
    span.record("time_per_token", format!("{time_per_token:?}"));
    span.record("seed", format!("{:?}", response.generated_text.seed));
    tracing::info!("Output: {}", response.generated_text.text);

    // Metrics
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_duration", total_time);
    metrics::histogram!("tgi_request_validation_duration", validation_time);
    metrics::histogram!("tgi_request_queue_duration", queue_time);
    metrics::histogram!("tgi_request_inference_duration", inference_time);
    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
    metrics::histogram!(
        "tgi_request_generated_tokens",
        response.generated_text.generated_tokens as f64
    );

    // Send response
    let response = GenerateResponse {
        generated_text: response.generated_text.text,
        details,
    };
    Ok((headers, Json(response)))
}

/// Generate a stream of tokens using Server-Sent Events
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate_stream",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = StreamResponse,
            content_type="text/event-stream"),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"}),
            content_type="text/event-stream"),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"}),
            content_type="text/event-stream"),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"}),
            content_type="text/event-stream"),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"}),
            content_type="text/event-stream"),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate_stream(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let stream = async_stream::stream! {
        // Inference
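        // Track whether a final `End` message or an error was seen so an
        // `IncompleteGeneration` error can be emitted if the stream stops early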
        let mut end_reached = false;
        let mut error = false;
        let details = req.0.parameters.details;

        match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
            Ok(mut response_stream) => {
                // Server-Sent Event stream
                while let Some(response) = response_stream.next().await {
                    match response {
                        Ok(response) => {
                            match response {
                                // Prefill is ignored
                                InferStreamResponse::Prefill(_) => {}
                                // Yield event for every new token
                                InferStreamResponse::Token(token) => {
                                    // StreamResponse
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: None,
                                        details: None,
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                                // Yield event for last token and compute timings
                                InferStreamResponse::End {
                                    token,
                                    generated_text,
                                    start,
                                    queued,
                                } => {
                                    // Token details
                                    let details = match details {
                                        true => Some(StreamDetails {
                                            finish_reason: FinishReason::from(generated_text.finish_reason),
                                            generated_tokens: generated_text.generated_tokens,
                                            seed: generated_text.seed,
                                        }),
                                        false => None,
                                    };

                                    // Timings
                                    let total_time = start_time.elapsed();
                                    let validation_time = queued - start_time;
                                    let queue_time = start - queued;
                                    let inference_time = Instant::now() - start;
                                    let time_per_token = inference_time / generated_text.generated_tokens;

                                    // Tracing metadata
                                    span.record("total_time", format!("{:?}", total_time));
                                    span.record("validation_time", format!("{:?}", validation_time));
                                    span.record("queue_time", format!("{:?}", queue_time));
                                    span.record("inference_time", format!("{:?}", inference_time));
                                    span.record("time_per_token", format!("{:?}", time_per_token));
                                    span.record("seed", format!("{:?}", generated_text.seed));
                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                    // Metrics
                                    metrics::increment_counter!("tgi_request_success");
                                    metrics::histogram!("tgi_request_duration", total_time);
                                    metrics::histogram!("tgi_request_validation_duration", validation_time);
                                    metrics::histogram!("tgi_request_queue_duration", queue_time);
                                    metrics::histogram!("tgi_request_inference_duration", inference_time);
                                    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
                                    metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                    // StreamResponse
                                    end_reached = true;
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: Some(generated_text.text),
                                        details
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                            }
                        }
                        // yield error
                        Err(err) => {
                            error = true;
                            yield Ok(Event::from(err))
                        }
                    }
                }
            },
            // yield error
            Err(err) => {
                error = true;
                yield Ok(Event::from(err))
            }
        }
        // Check if generation reached the end
        // Skip if we already sent an error
        if !end_reached && !error {
            let err = InferError::IncompleteGeneration;
            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
            tracing::error!("{err}");
            yield Ok(Event::from(err))
        }
    };

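    // Periodic keep-alive events prevent idle proxies from closing the SSE connection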
    Sse::new(stream).keep_alive(KeepAlive::default())
}

/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/metrics",
    responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    max_concurrent_requests: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
) {
    // OpenAPI documentation
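    // The derived specification is served by SwaggerUi at `/docs`, with the raw JSON
    // document available at `/api-doc/openapi.json` (see the router below)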
    #[derive(OpenApi)]
    #[openapi(
        paths(
            generate,
            generate_stream,
            metrics,
        ),
        components(
            schemas(
                GenerateRequest,
                GenerateParameters,
                Token,
                GenerateResponse,
                Details,
                FinishReason,
                StreamResponse,
                StreamDetails,
                ErrorResponse,
            )
        ),
        tags(
            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
        ),
        info(
            title = "Text Generation Inference",
            license(
                name = "Apache 2.0",
                url = "https://www.apache.org/licenses/LICENSE-2.0"
            )
        )
    )]
    struct ApiDoc;

    // Create state
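    // Build the shared state: request validation (tokenization and input limits) and the
    // `Infer` handle that batches requests in front of the sharded gRPC client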
    let validation = Validation::new(
        validation_workers,
        tokenizer,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
    );
    let infer = Infer::new(
        client,
        validation,
        max_batch_size,
        max_waiting_tokens,
        max_concurrent_requests,
    );

    // Prometheus handler
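    // `install_recorder` registers a global recorder for the `metrics` macros used in the
    // handlers and returns a handle that the `/metrics` route renders on scrape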
    let builder = PrometheusBuilder::new();
    let prom_handle = builder
        .install_recorder()
        .expect("failed to install metrics recorder");

    // Create router
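    // In axum, `.layer(...)` only applies to routes registered before it, so the
    // generation and health routes receive `Extension(infer)`, while `/metrics` only
    // receives `Extension(prom_handle)` and the tracing layer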
    let app = Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
        .route("/", post(generate))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
        .layer(Extension(infer))
        .route("/metrics", get(metrics))
        .layer(Extension(prom_handle))
        .layer(opentelemetry_tracing_layer());

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}

/// Shutdown signal handler
async fn shutdown_signal() {
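    // Resolve once either Ctrl+C or, on Unix, SIGTERM is received so container
    // orchestrators can trigger the graceful shutdown path above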
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    opentelemetry::global::shutdown_tracer_provider();
}

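/// Map the protobuf `FinishReason` integer returned by the sharded client to the HTTP API enum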
impl From<i32> for FinishReason {
    fn from(finish_reason: i32) -> Self {
        let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
        match finish_reason {
            text_generation_client::FinishReason::Length => FinishReason::Length,
            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}

/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
        };

        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
            }),
        )
    }
}

impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
            })
            .unwrap()
    }
}