/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
    CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
    GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
    Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

/// Compatibility route with api-inference and AzureML
#[instrument(skip(infer))]
async fn compat_generate(
    default_return_full_text: Extension<bool>,
    infer: Extension<Infer>,
    req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let mut req = req.0;

    // default return_full_text given the pipeline_tag
    if req.parameters.return_full_text.is_none() {
        req.parameters.return_full_text = Some(default_return_full_text.0)
    }

    // switch on stream
    if req.stream {
        Ok(generate_stream(infer, Json(req.into()))
            .await
            .into_response())
    } else {
        let (headers, generation) = generate(infer, Json(req.into())).await?;
        // wrap generation inside a Vec to match api-inference
        Ok((headers, Json(vec![generation.0])).into_response())
    }
}
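// Illustrative request for this compatibility route (a sketch, not a spec: the
// field names come from `CompatGenerateRequest` and `GenerateParameters` as used
// in this crate, the values are made up):
//
//   POST /
//   {"inputs": "My name is", "stream": false, "parameters": {"max_new_tokens": 20}}
//
// With `"stream": true` the same route answers with Server-Sent Events; otherwise
// the single generation is wrapped in a JSON array, as above.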

/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    //       be a bit too slow for a health check.
    //       What we should do instead is check if the gRPC channels are still healthy.

    // Send a small inference request
    infer
        .generate(GenerateRequest {
            inputs: "liveness".to_string(),
            parameters: GenerateParameters {
                temperature: None,
                repetition_penalty: None,
                top_k: None,
                top_p: None,
                do_sample: false,
                max_new_tokens: 1,
                return_full_text: None,
                stop: Vec::new(),
                watermark: false,
                details: false,
                seed: None,
            },
        })
        .await?;
    Ok(())
}
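// Rough usage sketch (the router below mounts this handler on `GET /` and
// `GET /health`; host and port are whatever `addr` is passed to `run`, the value
// below is an assumption):
//
//   curl -i http://localhost:3000/health
//
// A 200 with an empty body means a one-token "liveness" generation made it through
// the whole stack; failures surface as the error statuses mapped at the bottom of
// this file.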

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = GenerateResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let compute_characters = req.0.inputs.chars().count();

    let mut add_prompt = None;
    if req.0.parameters.return_full_text.unwrap_or(false) {
        add_prompt = Some(req.0.inputs.clone());
    }

    let details = req.0.parameters.details;

    // Inference
    let response = infer.generate(req.0).await?;

    // Token details
    let details = match details {
        true => Some(Details {
            finish_reason: FinishReason::from(response.generated_text.finish_reason),
            generated_tokens: response.generated_text.generated_tokens,
            prefill: response.prefill,
            tokens: response.tokens,
            seed: response.generated_text.seed,
        }),
        false => None,
    };

    // Timings
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Headers
    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
    headers.insert(
        "x-compute-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Tracing metadata
    span.record("total_time", format!("{total_time:?}"));
    span.record("validation_time", format!("{validation_time:?}"));
    span.record("queue_time", format!("{queue_time:?}"));
    span.record("inference_time", format!("{inference_time:?}"));
    span.record("time_per_token", format!("{time_per_token:?}"));
    span.record("seed", format!("{:?}", response.generated_text.seed));
    tracing::info!("Output: {}", response.generated_text.text);

    // Metrics
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_duration", total_time);
    metrics::histogram!("tgi_request_validation_duration", validation_time);
    metrics::histogram!("tgi_request_queue_duration", queue_time);
    metrics::histogram!("tgi_request_inference_duration", inference_time);
    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
    metrics::histogram!(
        "tgi_request_generated_tokens",
        response.generated_text.generated_tokens as f64
    );

    // Send response
    let mut output_text = response.generated_text.text;
    if let Some(prompt) = add_prompt {
        output_text = prompt + &output_text;
    }

    let response = GenerateResponse {
        generated_text: output_text,
        details,
    };
    Ok((headers, Json(response)))
}
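// Illustrative success body for `POST /generate` (a sketch: the shape follows the
// `GenerateResponse` and `Details` structs used above, the concrete values and the
// serialized spelling of `FinishReason` are assumptions):
//
//   {
//     "generated_text": " a text generation server.",
//     "details": {"finish_reason": "length", "generated_tokens": 20,
//                 "prefill": [...], "tokens": [...], "seed": null}
//   }
//
// When `parameters.details` is false the handler sets `details` to `None`, so the
// field is absent or null depending on the struct's serde configuration.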

/// Generate a stream of tokens using Server-Sent Events
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate_stream",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = StreamResponse,
            content_type = "text/event-stream"),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"}),
            content_type = "text/event-stream"),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"}),
            content_type = "text/event-stream"),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"}),
            content_type = "text/event-stream"),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"}),
            content_type = "text/event-stream"),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate_stream(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> (
    HeaderMap,
    Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let compute_characters = req.0.inputs.chars().count();

    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );

    let stream = async_stream::stream! {
        // Inference
        let mut end_reached = false;
        let mut error = false;

        let mut add_prompt = None;
        if req.0.parameters.return_full_text.unwrap_or(false) {
            add_prompt = Some(req.0.inputs.clone());
        }
        let details = req.0.parameters.details;

        match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
            Ok(mut response_stream) => {
                // Server-Sent Event stream
                while let Some(response) = response_stream.next().await {
                    match response {
                        Ok(response) => {
                            match response {
                                // Prefill is ignored
                                InferStreamResponse::Prefill(_) => {}
                                // Yield event for every new token
                                InferStreamResponse::Token(token) => {
                                    // StreamResponse
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: None,
                                        details: None,
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                                // Yield event for last token and compute timings
                                InferStreamResponse::End {
                                    token,
                                    generated_text,
                                    start,
                                    queued,
                                } => {
                                    // Token details
                                    let details = match details {
                                        true => Some(StreamDetails {
                                            finish_reason: FinishReason::from(generated_text.finish_reason),
                                            generated_tokens: generated_text.generated_tokens,
                                            seed: generated_text.seed,
                                        }),
                                        false => None,
                                    };

                                    // Timings
                                    let total_time = start_time.elapsed();
                                    let validation_time = queued - start_time;
                                    let queue_time = start - queued;
                                    let inference_time = Instant::now() - start;
                                    let time_per_token = inference_time / generated_text.generated_tokens;

                                    // Tracing metadata
                                    span.record("total_time", format!("{total_time:?}"));
                                    span.record("validation_time", format!("{validation_time:?}"));
                                    span.record("queue_time", format!("{queue_time:?}"));
                                    span.record("inference_time", format!("{inference_time:?}"));
                                    span.record("time_per_token", format!("{time_per_token:?}"));
                                    span.record("seed", format!("{:?}", generated_text.seed));
                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                    // Metrics
                                    metrics::increment_counter!("tgi_request_success");
                                    metrics::histogram!("tgi_request_duration", total_time);
                                    metrics::histogram!("tgi_request_validation_duration", validation_time);
                                    metrics::histogram!("tgi_request_queue_duration", queue_time);
                                    metrics::histogram!("tgi_request_inference_duration", inference_time);
                                    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
                                    metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                    // StreamResponse
                                    end_reached = true;

                                    let mut output_text = generated_text.text;
                                    if let Some(prompt) = add_prompt {
                                        output_text = prompt + &output_text;
                                    }

                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: Some(output_text),
                                        details
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap());
                                    break;
                                }
                            }
                        }
                        // yield error
                        Err(err) => {
                            error = true;
                            yield Ok(Event::from(err));
                            break;
                        }
                    }
                }
            },
            // yield error
            Err(err) => {
                error = true;
                yield Ok(Event::from(err));
            }
        }
        // Check if generation reached the end
        // Skip if we already sent an error
        if !end_reached && !error {
            let err = InferError::IncompleteGeneration;
            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        }
    };

    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}
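// Illustrative wire format for the stream above (a sketch: each `StreamResponse`
// is sent as one SSE `data:` event; the `token` payload is elided because its
// fields are defined elsewhere in this crate):
//
//   data: {"token": {...}, "generated_text": null, "details": null}
//   data: {"token": {...}, "generated_text": "<full text>", "details": {...}}
//
// Intermediate tokens carry no `generated_text`; the final event carries the
// concatenated text (prompt-prefixed when `return_full_text` is set) and, if
// requested, the `StreamDetails`.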

/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/metrics",
    responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}
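// Hedged scrape sketch (host and port depend on `addr`; the exact exposition
// format is whatever the `metrics_exporter_prometheus` recorder installed in
// `run` renders, only the metric names come from this file):
//
//   curl http://localhost:3000/metrics
//   ...
//   tgi_request_success 42
//   ...
//
// alongside the `tgi_request_*_duration` and `tgi_request_generated_tokens`
// series recorded by the handlers above.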

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    compat_return_full_text: bool,
    max_concurrent_requests: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
    allow_origin: Option<AllowOrigin>,
) {
    // OpenAPI documentation
    #[derive(OpenApi)]
    #[openapi(
        paths(
            generate,
            generate_stream,
            metrics,
        ),
        components(
            schemas(
                GenerateRequest,
                GenerateParameters,
                PrefillToken,
                Token,
                GenerateResponse,
                Details,
                FinishReason,
                StreamResponse,
                StreamDetails,
                ErrorResponse,
            )
        ),
        tags(
            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
        ),
        info(
            title = "Text Generation Inference",
            license(
                name = "Apache 2.0",
                url = "https://www.apache.org/licenses/LICENSE-2.0"
            )
        )
    )]
    struct ApiDoc;

    // Create state
    let validation = Validation::new(
        validation_workers,
        tokenizer,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
    );
    let infer = Infer::new(
        client,
        validation,
        max_batch_size,
        max_waiting_tokens,
        max_concurrent_requests,
    );

    // Prometheus handler
    let builder = PrometheusBuilder::new();
    let prom_handle = builder
        .install_recorder()
        .expect("failed to install metrics recorder");

    // CORS layer
    let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE])
        .allow_origin(allow_origin);

    // Create router
    let app = Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
        .route("/", post(compat_generate))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
        .route("/metrics", get(metrics))
        .layer(Extension(compat_return_full_text))
        .layer(Extension(infer))
        .layer(Extension(prom_handle))
        .layer(opentelemetry_tracing_layer())
        .layer(cors_layer);

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}

/// Shutdown signal handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    opentelemetry::global::shutdown_tracer_provider();
}

impl From<i32> for FinishReason {
    fn from(finish_reason: i32) -> Self {
        let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
        match finish_reason {
            text_generation_client::FinishReason::Length => FinishReason::Length,
            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}

/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
        };

        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            }),
        )
    }
}

impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            })
            .unwrap()
    }
}