/// HTTP Server logic
use crate::infer::{InferError, InferStreamResponse};
use crate::{
    CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
    GenerateRequest, GenerateResponse, Infer, PrefillToken, StreamDetails, StreamResponse, Token,
    Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use std::convert::Infallible;
use std::net::SocketAddr;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

/// Compatibility route with api-inference and AzureML
#[instrument(skip(infer))]
async fn compat_generate(
    default_return_full_text: Extension<bool>,
    infer: Extension<Infer>,
    req: Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let mut req = req.0;

    // default return_full_text given the pipeline_tag
    if req.parameters.return_full_text.is_none() {
        req.parameters.return_full_text = Some(default_return_full_text.0)
    }

    // switch on stream
    if req.stream {
        Ok(generate_stream(infer, Json(req.into()))
            .await
            .into_response())
    } else {
        let (headers, generation) = generate(infer, Json(req.into())).await?;
        // wrap generation inside a Vec to match api-inference
        Ok((headers, Json(vec![generation.0])).into_response())
    }
}
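// Illustrative request shape for the compatibility route above (field names follow
// CompatGenerateRequest; the concrete values are placeholders, not from this file):
//   POST /  {"inputs": "My name is", "parameters": {"max_new_tokens": 20}, "stream": false}
// With `stream: false` the single generation is wrapped in a JSON array to match the
// api-inference format; with `stream: true` the request is forwarded to `generate_stream`.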

/// Health check method
#[instrument(skip(infer))]
async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    // TODO: while this is the best health check we can do, it is a bit on the heavy side and might
    //       be a bit too slow for a health check.
    //       What we should do instead is check if the gRPC channels are still healthy.

    // Send a small inference request
    infer
        .generate(GenerateRequest {
            inputs: "liveness".to_string(),
            parameters: GenerateParameters {
                temperature: None,
                repetition_penalty: None,
                top_k: None,
                top_p: None,
                typical_p: None,
                do_sample: false,
                max_new_tokens: 1,
                return_full_text: None,
                stop: Vec::new(),
                watermark: false,
                details: false,
                seed: None,
            },
        })
        .await?;
    Ok(())
}

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = GenerateResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let compute_characters = req.0.inputs.chars().count();

    let mut add_prompt = None;
    if req.0.parameters.return_full_text.unwrap_or(false) {
        add_prompt = Some(req.0.inputs.clone());
    }

    let details = req.0.parameters.details;

    // Inference
    let response = infer.generate(req.0).await?;

    // Token details
    let details = match details {
        true => Some(Details {
            finish_reason: FinishReason::from(response.generated_text.finish_reason),
            generated_tokens: response.generated_text.generated_tokens,
            prefill: response.prefill,
            tokens: response.tokens,
            seed: response.generated_text.seed,
        }),
        false => None,
    };

    // Timings
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Headers
    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
    headers.insert(
        "x-compute-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );

    // Tracing metadata
    span.record("total_time", format!("{total_time:?}"));
    span.record("validation_time", format!("{validation_time:?}"));
    span.record("queue_time", format!("{queue_time:?}"));
    span.record("inference_time", format!("{inference_time:?}"));
    span.record("time_per_token", format!("{time_per_token:?}"));
    span.record("seed", format!("{:?}", response.generated_text.seed));
    tracing::info!("Output: {}", response.generated_text.text);

    // Metrics
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_duration", total_time);
    metrics::histogram!("tgi_request_validation_duration", validation_time);
    metrics::histogram!("tgi_request_queue_duration", queue_time);
    metrics::histogram!("tgi_request_inference_duration", inference_time);
    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
    metrics::histogram!(
        "tgi_request_generated_tokens",
        response.generated_text.generated_tokens as f64
    );

    // Send response
    let mut output_text = response.generated_text.text;
    if let Some(prompt) = add_prompt {
        output_text = prompt + &output_text;
    }

    let response = GenerateResponse {
        generated_text: output_text,
        details,
    };
    Ok((headers, Json(response)))
}

/// Generate a stream of tokens using Server-Sent Events
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate_stream",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = StreamResponse,
            content_type = "text/event-stream"),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"}),
            content_type = "text/event-stream"),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"}),
            content_type = "text/event-stream"),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"}),
            content_type = "text/event-stream"),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"}),
            content_type = "text/event-stream"),
    )
)]
#[instrument(
    skip(infer),
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate_stream(
    infer: Extension<Infer>,
    req: Json<GenerateRequest>,
) -> (
    HeaderMap,
    Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let compute_characters = req.0.inputs.chars().count();

    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );

    let stream = async_stream::stream! {
        // Inference
        let mut end_reached = false;
        let mut error = false;

        let mut add_prompt = None;
        if req.0.parameters.return_full_text.unwrap_or(false) {
            add_prompt = Some(req.0.inputs.clone());
        }
        let details = req.0.parameters.details;

        match infer.generate_stream(req.0).instrument(info_span!(parent: &span, "async_stream")).await {
            Ok(mut response_stream) => {
                // Server-Sent Event stream
                while let Some(response) = response_stream.next().await {
                    match response {
                        Ok(response) => {
                            match response {
                                // Prefill is ignored
                                InferStreamResponse::Prefill(_) => {}
                                // Yield event for every new token
                                InferStreamResponse::Token(token) => {
                                    // StreamResponse
                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: None,
                                        details: None,
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap())
                                }
                                // Yield event for last token and compute timings
                                InferStreamResponse::End {
                                    token,
                                    generated_text,
                                    start,
                                    queued,
                                } => {
                                    // Token details
                                    let details = match details {
                                        true => Some(StreamDetails {
                                            finish_reason: FinishReason::from(generated_text.finish_reason),
                                            generated_tokens: generated_text.generated_tokens,
                                            seed: generated_text.seed,
                                        }),
                                        false => None,
                                    };

                                    // Timings
                                    let total_time = start_time.elapsed();
                                    let validation_time = queued - start_time;
                                    let queue_time = start - queued;
                                    let inference_time = Instant::now() - start;
                                    let time_per_token = inference_time / generated_text.generated_tokens;

                                    // Tracing metadata
                                    span.record("total_time", format!("{total_time:?}"));
                                    span.record("validation_time", format!("{validation_time:?}"));
                                    span.record("queue_time", format!("{queue_time:?}"));
                                    span.record("inference_time", format!("{inference_time:?}"));
                                    span.record("time_per_token", format!("{time_per_token:?}"));
                                    span.record("seed", format!("{:?}", generated_text.seed));
                                    tracing::info!(parent: &span, "Output: {}", generated_text.text);

                                    // Metrics
                                    metrics::increment_counter!("tgi_request_success");
                                    metrics::histogram!("tgi_request_duration", total_time);
                                    metrics::histogram!("tgi_request_validation_duration", validation_time);
                                    metrics::histogram!("tgi_request_queue_duration", queue_time);
                                    metrics::histogram!("tgi_request_inference_duration", inference_time);
                                    metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
                                    metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                    // StreamResponse
                                    end_reached = true;

                                    let mut output_text = generated_text.text;
                                    if let Some(prompt) = add_prompt {
                                        output_text = prompt + &output_text;
                                    }

                                    let stream_token = StreamResponse {
                                        token,
                                        generated_text: Some(output_text),
                                        details
                                    };

                                    yield Ok(Event::default().json_data(stream_token).unwrap());
                                    break;
                                }
                            }
                        }
                        // yield error
                        Err(err) => {
                            error = true;
                            yield Ok(Event::from(err));
                            break;
                        }
                    }
                }
            },
            // yield error
            Err(err) => {
                error = true;
                yield Ok(Event::from(err));
            }
        }
        // Check if generation reached the end
        // Skip if we already sent an error
        if !end_reached && !error {
            let err = InferError::IncompleteGeneration;
            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        }
    };

    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}
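// Illustrative wire format for the stream above (shape follows StreamResponse; the values
// are placeholders): every intermediate token is emitted as one SSE event, e.g.
//   data: {"token": {...}, "generated_text": null, "details": null}
// and the final event carries the concatenated `generated_text` plus `details` when the
// request asked for them.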

/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/metrics",
    responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}
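// Note: `render()` returns the metrics gathered by the Prometheus recorder installed in
// `run` below, in the plain-text exposition format Prometheus scrapes.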

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    compat_return_full_text: bool,
    max_concurrent_requests: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    max_batch_size: usize,
    max_waiting_tokens: usize,
    client: ShardedClient,
    tokenizer: Tokenizer,
    validation_workers: usize,
    addr: SocketAddr,
    allow_origin: Option<AllowOrigin>,
) {
    // OpenAPI documentation
    #[derive(OpenApi)]
    #[openapi(
        paths(
            generate,
            generate_stream,
            metrics,
        ),
        components(
            schemas(
                GenerateRequest,
                GenerateParameters,
                PrefillToken,
                Token,
                GenerateResponse,
                Details,
                FinishReason,
                StreamResponse,
                StreamDetails,
                ErrorResponse,
            )
        ),
        tags(
            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
        ),
        info(
            title = "Text Generation Inference",
            license(
                name = "Apache 2.0",
                url = "https://www.apache.org/licenses/LICENSE-2.0"
            )
        )
    )]
    struct ApiDoc;

    // Create state
    let validation = Validation::new(
        validation_workers,
        tokenizer,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
    );
    let infer = Infer::new(
        client,
        validation,
        max_batch_size,
        max_waiting_tokens,
        max_concurrent_requests,
    );

    // Prometheus handler
    let builder = PrometheusBuilder::new();
    let prom_handle = builder
        .install_recorder()
        .expect("failed to install metrics recorder");

    // CORS layer
    let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE])
        .allow_origin(allow_origin);

    // Create router
    let app = Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
        .route("/", post(compat_generate))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/", get(health))
        .route("/health", get(health))
        .route("/metrics", get(metrics))
        .layer(Extension(compat_return_full_text))
        .layer(Extension(infer))
        .layer(Extension(prom_handle))
        .layer(opentelemetry_tracing_layer())
        .layer(cors_layer);

    // Run server
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        .await
        .unwrap();
}
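// Call-site sketch (illustrative only, not part of this file): the router binary is expected
// to construct the ShardedClient and Tokenizer first and then hand everything to `run`, which
// plumbs the limits straight into Validation/Infer above and blocks until shutdown.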

/// Shutdown signal handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    opentelemetry::global::shutdown_tracer_provider();
}

impl From<i32> for FinishReason {
    fn from(finish_reason: i32) -> Self {
        let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
        match finish_reason {
            text_generation_client::FinishReason::Length => FinishReason::Length,
            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}
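// Illustrative check (added; not part of the original file): a minimal test sketch for the
// conversion above, assuming the prost-generated client enum exposes its i32 discriminants
// via `as i32`. Module and test names are hypothetical.
#[cfg(test)]
mod finish_reason_conversion_tests {
    use super::*;

    #[test]
    fn maps_eos_token_variant() {
        // Only exercises the From<i32> impl defined above.
        let reason = FinishReason::from(text_generation_client::FinishReason::EosToken as i32);
        assert!(matches!(reason, FinishReason::EndOfSequenceToken));
    }
}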

/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
        };

        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            }),
        )
    }
}

impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            })
            .unwrap()
    }
}