/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
    CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
    GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
    Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

/// Compatibility route with api-inference and AzureML
#[instrument(skip(infer))]
async fn compat_generate(
    default_return_full_text: Extension<bool>,
    infer: Extension<Infer>,
    req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let mut req = req.0;

    // default return_full_text given the pipeline_tag
    if req.parameters.return_full_text.is_none() {
        req.parameters.return_full_text = Some(default_return_full_text.0)
    }

    // switch on stream
    if req.stream {
        Ok(generate_stream(infer, Json(req.into()))
            .await
            .into_response())
    } else {
        let (headers, generation) = generate(infer, Json(req.into())).await?;
        // wrap generation inside a Vec to match api-inference
        Ok((headers, Json(vec![generation.0])).into_response())
    }
}

/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    //       be a bit too slow for a health check.
    //       What we should do instead is check if the gRPC channels are still healthy.

    // Send a small inference request
    infer
        .generate(GenerateRequest {
            inputs: "liveness".to_string(),
            parameters: GenerateParameters {
                temperature: None,
                repetition_penalty: None,
                top_k: None,
                top_p: None,
                typical_p: None,
                do_sample: false,
                max_new_tokens: 1,
                return_full_text: None,
                stop: Vec::new(),
                truncate: None,
                watermark: false,
                details: false,
                seed: None,
            },
        })
        .await?;
    Ok(())
}

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = GenerateResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let compute_characters = req.0.inputs.chars().count();

    let mut add_prompt = None;
    if req.0.parameters.return_full_text.unwrap_or(false) {
        add_prompt = Some(req.0.inputs.clone());
    }

    let details = req.0.parameters.details;

    // Inference
    let response = infer.generate(req.0).await?;

    // Token details
    let details = match details {
        true => Some(Details {
            finish_reason: FinishReason::from(response.generated_text.finish_reason),
            generated_tokens: response.generated_text.generated_tokens,
            prefill: response.prefill,
            tokens: response.tokens,
            seed: response.generated_text.seed,
        }),
        false => None,
    };

    // Timings
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Headers
    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
    headers.insert(
        "x-compute-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Tracing metadata
    span.record("total_time", format!("{total_time:?}"));
    span.record("validation_time", format!("{validation_time:?}"));
    span.record("queue_time", format!("{queue_time:?}"));
    span.record("inference_time", format!("{inference_time:?}"));
    span.record("time_per_token", format!("{time_per_token:?}"));
    span.record("seed", format!("{:?}", response.generated_text.seed));
    tracing::info!("Output: {}", response.generated_text.text);

    // Metrics
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_duration", total_time);
    metrics::histogram!("tgi_request_validation_duration", validation_time);
    metrics::histogram!("tgi_request_queue_duration", queue_time);
    metrics::histogram!("tgi_request_inference_duration", inference_time);
    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
    metrics::histogram!(
        "tgi_request_generated_tokens",
        response.generated_text.generated_tokens as f64
    );

    // Send response
    let mut output_text = response.generated_text.text;
    if let Some(prompt) = add_prompt {
        output_text = prompt + &output_text;
    }

    let response = GenerateResponse {
        generated_text: output_text,
        details,
    };
    Ok((headers, Json(response)))
}

/// Generate a stream of tokens using Server-Sent Events
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate_stream",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = StreamResponse,
            content_type = "text/event-stream"),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"}),
            content_type = "text/event-stream"),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"}),
            content_type = "text/event-stream"),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"}),
            content_type = "text/event-stream"),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"}),
            content_type = "text/event-stream"),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate_stream(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> (
    HeaderMap,
    Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let compute_characters = req.0.inputs.chars().count();

    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );

    let stream = async_stream::stream! {
        // Inference
        let mut end_reached = false;
        let mut error = false;
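        // These flags let the end of the stream detect an incomplete generation:
        // if neither a final token nor an error was yielded, an IncompleteGeneration
        // error is emitted after the loop below.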

        let mut add_prompt = None;
        if req.0.parameters.return_full_text.unwrap_or(false) {
            add_prompt = Some(req.0.inputs.clone());
        }
        let details = req.0.parameters.details;

        match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
            Ok(mut response_stream) => {
                // Server-Sent Event stream
                while let Some(response) = response_stream.next().await {
                    match response {
                        Ok(response) => {
                            match response {
                                // Prefill is ignored
                                InferStreamResponse::Prefill(_) => {}
                                // Yield event for every new token
                                InferStreamResponse::Token(token) => {
                                    // StreamResponse
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: None,
                                        details: None,
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                                // Yield event for last token and compute timings
                                InferStreamResponse::End {
                                    token,
                                    generated_text,
                                    start,
                                    queued,
                                } => {
                                    // Token details
                                    let details = match details {
                                        true => Some(StreamDetails {
                                            finish_reason: FinishReason::from(generated_text.finish_reason),
                                            generated_tokens: generated_text.generated_tokens,
                                            seed: generated_text.seed,
                                        }),
                                        false => None,
                                    };

                                    // Timings
                                    let total_time = start_time.elapsed();
                                    let validation_time = queued - start_time;
                                    let queue_time = start - queued;
                                    let inference_time = Instant::now() - start;
                                    let time_per_token = inference_time / generated_text.generated_tokens;

                                    // Tracing metadata
                                    span.record("total_time", format!("{total_time:?}"));
                                    span.record("validation_time", format!("{validation_time:?}"));
                                    span.record("queue_time", format!("{queue_time:?}"));
                                    span.record("inference_time", format!("{inference_time:?}"));
                                    span.record("time_per_token", format!("{time_per_token:?}"));
                                    span.record("seed", format!("{:?}", generated_text.seed));
                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                    // Metrics
                                    metrics::increment_counter!("tgi_request_success");
                                    metrics::histogram!("tgi_request_duration", total_time);
                                    metrics::histogram!("tgi_request_validation_duration", validation_time);
                                    metrics::histogram!("tgi_request_queue_duration", queue_time);
                                    metrics::histogram!("tgi_request_inference_duration", inference_time);
                                    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
                                    metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                    // StreamResponse
                                    end_reached = true;

                                    let mut output_text = generated_text.text;
                                    if let Some(prompt) = add_prompt {
                                        output_text = prompt + &output_text;
                                    }

                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: Some(output_text),
                                        details
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap());
                                    break;
                                }
                            }
                        }
                        // yield error
                        Err(err) => {
                            error = true;
                            yield Ok(Event::from(err));
                            break;
                        }
                    }
                }
            },
            // yield error
            Err(err) => {
                error = true;
                yield Ok(Event::from(err));
            }
        }
        // Check if generation reached the end
        // Skip if we already sent an error
        if !end_reached && !error {
            let err = InferError::IncompleteGeneration;
            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        }
    };

    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}

/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/metrics",
    responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    compat_return_full_text: bool,
    max_concurrent_requests: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
    allow_origin: Option<AllowOrigin>,
) {
    // OpenAPI documentation
    #[derive(OpenApi)]
    #[openapi(
        paths(
            generate,
            generate_stream,
            metrics,
        ),
        components(
            schemas(
                GenerateRequest,
                GenerateParameters,
                PrefillToken,
                Token,
                GenerateResponse,
                Details,
                FinishReason,
                StreamResponse,
                StreamDetails,
                ErrorResponse,
            )
        ),
        tags(
            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
        ),
        info(
            title = "Text Generation Inference",
            license(
                name = "Apache 2.0",
                url = "https://www.apache.org/licenses/LICENSE-2.0"
            )
        )
    )]
    struct ApiDoc;

    // Create state
    let validation = Validation::new(
        validation_workers,
        tokenizer,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
    );
    let infer = Infer::new(
        client,
        validation,
        max_batch_size,
        max_waiting_tokens,
        max_concurrent_requests,
    );

    // Prometheus handler
    let builder = PrometheusBuilder::new();
    let prom_handle = builder
        .install_recorder()
        .expect("failed to install metrics recorder");

    // CORS layer
    let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE])
        .allow_origin(allow_origin);

    // Create router
    let app = Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
        .route("/", post(compat_generate))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
        .route("/metrics", get(metrics))
        .layer(Extension(compat_return_full_text))
        .layer(Extension(infer))
        .layer(Extension(prom_handle))
        .layer(opentelemetry_tracing_layer())
        .layer(cors_layer);

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}

/// Shutdown signal handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

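    // Wait for either Ctrl+C or SIGTERM, whichever arrives first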
    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    opentelemetry::global::shutdown_tracer_provider();
}

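/// Map the protobuf `FinishReason` returned by the sharded client to the HTTP API enum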
impl From<i32> for FinishReason {
    fn from(finish_reason: i32) -> Self {
        let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
        match finish_reason {
            text_generation_client::FinishReason::Length => FinishReason::Length,
            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}

/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
        };

        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            }),
        )
    }
}

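/// Convert errors into SSE events so they can be surfaced on the streaming route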
impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            })
            .unwrap()
    }
}