use crate::config::Config;
/// HTTP Server logic
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError;
use crate::{
    BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
    GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
    Validation,
};
use crate::{
    ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
    ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
    CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
};
use crate::{FunctionDefinition, ToolCall, ToolType};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::stream::FuturesUnordered;
use futures::stream::StreamExt;
use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient};
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

/// Generate tokens if `stream == false` or a stream of tokens if `stream == true`
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/",
request_body = CompatGenerateRequest,
responses(
(status = 200, description = "Generated Text",
content(
("application/json" = GenerateResponse),
("text/event-stream" = StreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn compat_generate(
    Extension(default_return_full_text): Extension<bool>,
    infer: Extension<Infer>,
    compute_type: Extension<ComputeType>,
    Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    // default return_full_text given the pipeline_tag
    if req.parameters.return_full_text.is_none() {
        req.parameters.return_full_text = Some(default_return_full_text)
    }

    // switch on stream
    if req.stream {
        Ok(generate_stream(infer, compute_type, Json(req.into()))
            .await
            .into_response())
    } else {
        let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
        // wrap generation inside a Vec to match api-inference
        Ok((headers, Json(vec![generation])).into_response())
    }
}

/// Text Generation Inference endpoint info
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/info",
responses((status = 200, description = "Served model info", body = Info))
)]
#[instrument]
async fn get_model_info(info: Extension<Info>) -> Json<Info> {
    Json(info.0)
}

#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/health",
responses(
(status = 200, description = "Everything is working fine"),
(status = 503, description = "Text generation inference is down", body = ErrorResponse,
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)
)]
#[instrument(skip(health))]
/// Health check method
async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    match health.check().await {
        true => Ok(()),
        false => Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse {
                error: "unhealthy".to_string(),
                error_type: "healthcheck".to_string(),
            }),
        )),
    }
}

/// Generate tokens
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = GenerateResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
parameters = ? req.parameters,
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn generate(
    infer: Extension<Infer>,
    Extension(ComputeType(compute_type)): Extension<ComputeType>,
    Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();
    metrics::increment_counter!("tgi_request_count");

    // Do not log ultra long inputs, like image payloads.
    tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);

    let compute_characters = req.inputs.chars().count();
    let mut add_prompt = None;
    if req.parameters.return_full_text.unwrap_or(false) {
        add_prompt = Some(req.inputs.clone());
    }

    let details: bool = req.parameters.details || req.parameters.decoder_input_details;

    // Inference
    let (response, best_of_responses) = match req.parameters.best_of {
        Some(best_of) if best_of > 1 => {
            let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
            (response, Some(best_of_responses))
        }
        _ => (infer.generate(req).await?, None),
    };

    // Token details
    let input_length = response._input_length;
    let details = match details {
        true => {
            // convert best_of_responses
            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
                responses
                    .into_iter()
                    .map(|response: InferResponse| {
                        // Add prompt if return_full_text
                        let mut output_text = response.generated_text.text;
                        if let Some(prompt) = &add_prompt {
                            output_text = prompt.clone() + &output_text;
                        }

                        BestOfSequence {
                            generated_text: output_text,
                            finish_reason: FinishReason::from(
                                response.generated_text.finish_reason,
                            ),
                            generated_tokens: response.generated_text.generated_tokens,
                            prefill: response.prefill,
                            tokens: response.tokens,
                            top_tokens: response.top_tokens,
                            seed: response.generated_text.seed,
                        }
                    })
                    .collect()
            });

            Some(Details {
                finish_reason: FinishReason::from(response.generated_text.finish_reason),
                generated_tokens: response.generated_text.generated_tokens,
                prefill: response.prefill,
                tokens: response.tokens,
                seed: response.generated_text.seed,
                best_of_sequences,
                top_tokens: response.top_tokens,
            })
        }
        false => None,
    };

    // Timings
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Tracing metadata
    span.record("total_time", format!("{total_time:?}"));
    span.record("validation_time", format!("{validation_time:?}"));
    span.record("queue_time", format!("{queue_time:?}"));
    span.record("inference_time", format!("{inference_time:?}"));
    span.record("time_per_token", format!("{time_per_token:?}"));
    span.record("seed", format!("{:?}", response.generated_text.seed));

    // Headers
    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", compute_type.parse().unwrap());
    headers.insert(
        "x-compute-time",
        total_time.as_secs_f64().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );
    headers.insert("x-prompt-tokens", input_length.into());
    headers.insert(
        "x-generated-tokens",
        response.generated_text.generated_tokens.into(),
    );

    // Metrics
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
    metrics::histogram!(
        "tgi_request_validation_duration",
        validation_time.as_secs_f64()
    );
    metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
    metrics::histogram!(
        "tgi_request_inference_duration",
        inference_time.as_secs_f64()
    );
    metrics::histogram!(
        "tgi_request_mean_time_per_token_duration",
        time_per_token.as_secs_f64()
    );
    metrics::histogram!(
        "tgi_request_generated_tokens",
        response.generated_text.generated_tokens as f64
    );

    // Send response
    let mut output_text = response.generated_text.text;
    if let Some(prompt) = add_prompt {
        output_text = prompt + &output_text;
    }

    tracing::debug!("Output: {}", output_text);
    tracing::info!("Success");

    let response = GenerateResponse {
        generated_text: output_text,
        details,
    };
    Ok((headers, Json(response)))
}

/// Generate a stream of tokens using Server-Sent Events
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate_stream",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = StreamResponse,
content_type = "text/event-stream"),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"}),
content_type = "text/event-stream"),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"}),
content_type = "text/event-stream"),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"}),
content_type = "text/event-stream"),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"}),
content_type = "text/event-stream"),
)
)]
#[instrument(
skip_all,
fields(
parameters = ? req.parameters,
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn generate_stream(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<GenerateRequest>,
) -> (
    HeaderMap,
    Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
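    // Each StreamResponse is serialized as the JSON payload of an SSE event.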
    let on_message_callback = |stream_token: StreamResponse| {
        let event = Event::default();
        event.json_data(stream_token).unwrap()
    };
    let (headers, response_stream) =
        generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
    let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
    (headers, sse)
}

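/// Shared streaming core used by `/generate_stream`, `/v1/completions` and `/v1/chat/completions`:
/// drives the inference stream and turns every `InferStreamResponse` into an SSE `Event`
/// through the caller-provided `on_message_callback`.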
async fn generate_stream_internal(
    infer: Infer,
    ComputeType(compute_type): ComputeType,
    Json(req): Json<GenerateRequest>,
    on_message_callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
    let span = tracing::Span::current();
    let start_time = Instant::now();
    metrics::increment_counter!("tgi_request_count");

    tracing::debug!("Input: {}", req.inputs);

    let compute_characters = req.inputs.chars().count();

    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", compute_type.parse().unwrap());
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    headers.insert("X-Accel-Buffering", "no".parse().unwrap());

    let stream = async_stream::stream! {
        // Inference
        let mut end_reached = false;
        let mut error = false;

        let mut add_prompt = None;
        if req.parameters.return_full_text.unwrap_or(false) {
            add_prompt = Some(req.inputs.clone());
        }
        let details = req.parameters.details;

        let best_of = req.parameters.best_of.unwrap_or(1);
        if best_of != 1 {
            let err = InferError::from(ValidationError::BestOfStream);
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        } else if req.parameters.decoder_input_details {
            let err = InferError::from(ValidationError::PrefillDetailsStream);
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        } else {
            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                // Keep permit as long as generate_stream lives
                Ok((_permit, _input_length, mut response_stream)) => {
                    let mut index = 0;
                    // Server-Sent Event stream
                    while let Some(response) = response_stream.next().await {
                        index += 1;
                        match response {
                            Ok(response) => {
                                match response {
                                    // Prefill is ignored
                                    InferStreamResponse::Prefill(_) => {}
                                    // Yield event for every new token
                                    InferStreamResponse::Intermediate{
                                        token,
                                        top_tokens,
                                    } => {
                                        tracing::debug!(parent: &span, "Token: {:?}", token);

                                        // StreamResponse
                                        let stream_token = StreamResponse {
                                            index,
                                            token,
                                            top_tokens,
                                            generated_text: None,
                                            details: None,
                                        };
                                        let event = on_message_callback(stream_token);
                                        yield Ok(event);
                                    }
                                    // Yield event for last token and compute timings
                                    InferStreamResponse::End {
                                        token,
                                        generated_text,
                                        start,
                                        queued,
                                        top_tokens,
                                    } => {
                                        // Token details
                                        let details = match details {
                                            true => Some(StreamDetails {
                                                finish_reason: FinishReason::from(generated_text.finish_reason),
                                                generated_tokens: generated_text.generated_tokens,
                                                seed: generated_text.seed,
                                            }),
                                            false => None,
                                        };

                                        // Timings
                                        let total_time = start_time.elapsed();
                                        let validation_time = queued - start_time;
                                        let queue_time = start - queued;
                                        let inference_time = Instant::now() - start;
                                        let time_per_token = inference_time / generated_text.generated_tokens;

                                        // Tracing metadata
                                        span.record("total_time", format!("{total_time:?}"));
                                        span.record("validation_time", format!("{validation_time:?}"));
                                        span.record("queue_time", format!("{queue_time:?}"));
                                        span.record("inference_time", format!("{inference_time:?}"));
                                        span.record("time_per_token", format!("{time_per_token:?}"));
                                        span.record("seed", format!("{:?}", generated_text.seed));

                                        // Metrics
                                        metrics::increment_counter!("tgi_request_success");
                                        metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64());
                                        metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                        // StreamResponse
                                        end_reached = true;

                                        let mut output_text = generated_text.text;
                                        if let Some(prompt) = add_prompt {
                                            output_text = prompt + &output_text;
                                        }

                                        tracing::debug!(parent: &span, "Output: {}", output_text);
                                        tracing::info!(parent: &span, "Success");

                                        let stream_token = StreamResponse {
                                            index,
                                            token,
                                            top_tokens,
                                            generated_text: Some(output_text),
                                            details
                                        };

                                        let event = on_message_callback(stream_token);
                                        yield Ok(event);
                                        break;
                                    }
                                }
                            }
                            // yield error
                            Err(err) => {
                                error = true;
                                yield Ok(Event::from(err));
                                break;
                            }
                        }
                    }
                },
                // yield error
                Err(err) => {
                    error = true;
                    yield Ok(Event::from(err));
                }
            }
            // Check if generation reached the end
            // Skip if we already sent an error
            if !end_reached && !error {
                let err = InferError::IncompleteGeneration;
                metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
                tracing::error!("{err}");
                yield Ok(Event::from(err));
            }
        }
    };

    (headers, stream)
}

/// Generate tokens (OpenAI-compatible completions endpoint)
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v1/completions",
    request_body = CompletionRequest,
    responses(
    (status = 200, description = "Generated Text", body = ChatCompletionChunk),
    (status = 424, description = "Generation Error", body = ErrorResponse,
    example = json ! ({"error": "Request failed during generation"})),
    (status = 429, description = "Model is overloaded", body = ErrorResponse,
    example = json ! ({"error": "Model is overloaded"})),
    (status = 422, description = "Input validation error", body = ErrorResponse,
    example = json ! ({"error": "Input validation error"})),
    (status = 500, description = "Incomplete generation", body = ErrorResponse,
    example = json ! ({"error": "Incomplete generation"})),
    )
    )]
#[instrument(
    skip_all,
    fields(
    // parameters = ? req.parameters,
    total_time,
    validation_time,
    queue_time,
    inference_time,
    time_per_token,
    seed,
    )
    )]
async fn completions(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Extension(info): Extension<Info>,
    Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    let stream = req.stream;
    let max_new_tokens = req.max_tokens.or(Some(100));
    let seed = req.seed;

    // if suffix is present throw an error
    if req.suffix.is_some() {
        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Suffix is not supported and can be achieved by preprocessing the prompt."
                    .to_string(),
                error_type: "suffix not supported".to_string(),
            }),
        ));
    }

    // build the request passing some parameters
    let generate_request = GenerateRequest {
        inputs: req.prompt.to_string(),
        parameters: GenerateParameters {
            best_of: None,
            temperature: req.temperature,
            repetition_penalty: req.repetition_penalty,
            frequency_penalty: req.frequency_penalty,
            top_k: None,
            top_p: req.top_p,
            typical_p: None,
            do_sample: true,
            max_new_tokens,
            return_full_text: None,
            stop: Vec::new(),
            truncate: None,
            watermark: false,
            details: true,
            decoder_input_details: !stream,
            seed,
            top_n_tokens: None,
            grammar: None,
        },
    };

    if stream {
        let on_message_callback = move |stream_token: StreamResponse| {
            let event = Event::default();

            let current_time = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                .as_secs();

            event
                .json_data(CompletionCompleteChunk {
                    id: "".to_string(),
                    object: "text_completion".to_string(),
                    created: current_time,

                    choices: vec![CompletionComplete {
                        finish_reason: "".to_string(),
                        index: 0,
                        logprobs: None,
                        text: stream_token.token.text,
                    }],

                    model: info.model_id.clone(),
                    system_fingerprint: format!(
                        "{}-{}",
                        info.version,
                        info.docker_label.unwrap_or("native")
                    ),
                })
                .map_or_else(
                    |e| {
                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
                        Event::default()
                    },
                    |data| data,
                )
        };

        let (headers, response_stream) = generate_stream_internal(
            infer,
            compute_type,
            Json(generate_request),
            on_message_callback,
        )
        .await;

        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
        Ok((headers, sse).into_response())
    } else {
        let (headers, Json(generation)) = generate(
            Extension(infer),
            Extension(compute_type),
            Json(generate_request),
        )
        .await?;

        let current_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
            .as_secs();

        let details = generation.details.ok_or((
            // this should never happen but handle if details are missing unexpectedly
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "No details in generation".to_string(),
                error_type: "no details".to_string(),
            }),
        ))?;

        let response = Completion {
            id: "".to_string(),
            object: "text_completion".to_string(),
            created: current_time,
            model: info.model_id.clone(),
            system_fingerprint: format!(
                "{}-{}",
                info.version,
                info.docker_label.unwrap_or("native")
            ),
            choices: vec![CompletionComplete {
                finish_reason: details.finish_reason.to_string(),
                index: 0,
                logprobs: None,
                text: generation.generated_text,
            }],
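            // token accounting: prompt tokens from the prefill, plus the generated token count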
            usage: Usage {
                prompt_tokens: details.prefill.len() as u32,
                completion_tokens: details.generated_tokens,
                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
            },
        };

        Ok((headers, Json(response)).into_response())
    }
}

/// Generate tokens (OpenAI-compatible chat completions endpoint)
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v1/chat/completions",
    request_body = ChatRequest,
    responses(
    (status = 200, description = "Generated Text", body = ChatCompletionChunk),
    (status = 424, description = "Generation Error", body = ErrorResponse,
    example = json ! ({"error": "Request failed during generation"})),
    (status = 429, description = "Model is overloaded", body = ErrorResponse,
    example = json ! ({"error": "Model is overloaded"})),
    (status = 422, description = "Input validation error", body = ErrorResponse,
    example = json ! ({"error": "Input validation error"})),
    (status = 500, description = "Incomplete generation", body = ErrorResponse,
    example = json ! ({"error": "Incomplete generation"})),
    )
    )]
#[instrument(
    skip_all,
    fields(
    // parameters = ? req.parameters,
    total_time,
    validation_time,
    queue_time,
    inference_time,
    time_per_token,
    seed,
    )
    )]
async fn chat_completions(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Extension(info): Extension<Info>,
    Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    let ChatRequest {
        logprobs,
        max_tokens,
        messages,
        presence_penalty,
        seed,
        stop,
        stream,
        tools,
        tool_choice,
        tool_prompt,
        ..
    } = req;

    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
    let max_new_tokens = max_tokens.or(Some(100));
    let logprobs = logprobs.unwrap_or(false);
    let tool_prompt = tool_prompt.unwrap_or_default();
    let stop = stop.unwrap_or_default();

    // extract tool grammar if present
    let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
        Ok(grammar) => grammar,
        Err(err) => {
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            return Err((
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(ErrorResponse {
                    error: err.to_string(),
                    error_type: err.error_type().to_string(),
                }),
            ));
        }
    };

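    // if tools are present, pair the JSON grammar derived from them with the tool prompt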
    let grammar_with_prompt = tool_grammar
        .as_ref()
        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));

    let typed_grammar = grammar_with_prompt
        .as_ref()
        .map(|(grammar, _)| grammar.clone());

    // apply chat template to flatten the request into a single input
    let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
        Ok(inputs) => inputs,
        Err(err) => {
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            return Err((
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(ErrorResponse {
                    error: err.to_string(),
                    error_type: err.error_type().to_string(),
                }),
            ));
        }
    };

    // build the request passing some parameters
    let generate_request = GenerateRequest {
        inputs: inputs.to_string(),
        parameters: GenerateParameters {
            best_of: None,
            temperature: req.temperature,
            repetition_penalty,
            frequency_penalty: req.frequency_penalty,
            top_k: None,
            top_p: req.top_p,
            typical_p: None,
            do_sample: true,
            max_new_tokens,
            return_full_text: None,
            stop,
            truncate: None,
            watermark: false,
            details: true,
            decoder_input_details: !stream,
            seed,
            top_n_tokens: req.top_logprobs,
            grammar: typed_grammar,
        },
    };

    // static values that will be returned in all cases
    let model_id = info.model_id.clone();
    let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));

    // switch on stream
    if stream {
        // pass this callback to the stream generation and build the required event structure
        let on_message_callback = move |stream_token: StreamResponse| {
            let event = Event::default();

            let current_time = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                .as_secs();

            let logprobs = logprobs.then(|| {
                ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
            });

            // replace the content with the tool calls if grammar is present
            let (content, tool_calls) = if tool_grammar.is_some() {
                (None, Some(vec![stream_token.token.text]))
            } else {
                (Some(stream_token.token.text), None)
            };

            event
                .json_data(ChatCompletionChunk::new(
                    model_id.clone(),
                    system_fingerprint.clone(),
                    content,
                    tool_calls,
                    current_time,
                    logprobs,
                    stream_token.details.map(|d| d.finish_reason.to_string()),
                ))
                .map_or_else(
                    |e| {
                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
                        Event::default()
                    },
                    |data| data,
                )
        };

        let (headers, response_stream) = generate_stream_internal(
            infer,
            compute_type,
            Json(generate_request),
            on_message_callback,
        )
        .await;
        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
        Ok((headers, sse).into_response())
    } else {
        let (headers, Json(generation)) = generate(
            Extension(infer),
            Extension(compute_type),
            Json(generate_request),
        )
        .await?;

        let current_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
            .as_secs();

        let (tool_calls, output) = if tool_grammar.is_some() {
            // gen_text should be valid json
            let gen_text_value: Value =
                serde_json::from_str(&generation.generated_text).map_err(|e| {
                    (
                        StatusCode::UNPROCESSABLE_ENTITY,
                        Json(ErrorResponse {
                            error: e.to_string(),
                            error_type: "Input validation error".to_string(),
                        }),
                    )
                })?;
            let tool_calls = vec![ToolCall {
                id: 0,
                r#type: "function".to_string(),
                function: FunctionDefinition {
                    description: None,
                    name: gen_text_value
                        .get("function")
                        .and_then(|f| f.get("_name"))
                        .and_then(|name| name.as_str())
                        .unwrap_or("default_function_name")
                        .to_string(),
                    // Serialize the JSON object obtained from "function" to an escaped JSON string
                    arguments: gen_text_value
                        .get("function")
                        .map(|f| {
                            let mut f_cloned = f.clone();
                            if let Value::Object(ref mut props) = f_cloned {
                                props.remove("_name");
                            }
                            f_cloned
                        })
                        .unwrap_or_default(),
                },
            }];
            (Some(tool_calls), None)
        } else {
            (None, Some(generation.generated_text))
        };
        // build the complete response object with the full text
        let response = ChatCompletion::new(
            model_id,
            system_fingerprint,
            output,
            current_time,
            generation.details.unwrap(),
            logprobs,
            tool_calls,
        );

        // build the complete response object and return it
        Ok((headers, Json(response)).into_response())
    }
}

/// Generate tokens from Vertex request
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/vertex",
    request_body = VertexRequest,
    responses(
    (status = 200, description = "Generated Text", body = VertexResponse),
    (status = 424, description = "Generation Error", body = ErrorResponse,
    example = json ! ({"error": "Request failed during generation"})),
    (status = 429, description = "Model is overloaded", body = ErrorResponse,
    example = json ! ({"error": "Model is overloaded"})),
    (status = 422, description = "Input validation error", body = ErrorResponse,
    example = json ! ({"error": "Input validation error"})),
    (status = 500, description = "Incomplete generation", body = ErrorResponse,
    example = json ! ({"error": "Incomplete generation"})),
    )
    )]
#[instrument(
    skip_all,
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn vertex_compatibility(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    // check that there's at least one instance
    if req.instances.is_empty() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Input validation error".to_string(),
                error_type: "Input validation error".to_string(),
            }),
        ));
    }

    // Process all instances
    let predictions = req
        .instances
        .iter()
        .map(|instance| {
            let generate_request = GenerateRequest {
                inputs: instance.inputs.clone(),
                parameters: GenerateParameters {
                    do_sample: true,
                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
                    details: true,
                    decoder_input_details: true,
                    ..Default::default()
                },
            };

            async {
                generate(
                    Extension(infer.clone()),
                    Extension(compute_type.clone()),
                    Json(generate_request),
                )
                .await
                .map(|(_, Json(generation))| generation.generated_text)
                .map_err(|_| {
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(ErrorResponse {
                            error: "Incomplete generation".into(),
                            error_type: "Incomplete generation".into(),
                        }),
                    )
                })
            }
        })
        .collect::<FuturesUnordered<_>>()
        .try_collect::<Vec<_>>()
        .await?;

    let response = VertexResponse { predictions };
    Ok((HeaderMap::new(), Json(response)).into_response())
}

/// Tokenize inputs
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/tokenize",
    request_body = GenerateRequest,
    responses(
    (status = 200, description = "Tokenized ids", body = TokenizeResponse),
    (status = 404, description = "No tokenizer found", body = ErrorResponse,
    example = json ! ({"error": "No fast tokenizer available"})),
    )
    )]
#[instrument(skip_all)]
async fn tokenize(
    Extension(infer): Extension<Infer>,
    Json(req): Json<GenerateRequest>,
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
    let input = req.inputs.clone();
    let encoding = infer.tokenize(req).await?;
    if let Some(encoding) = encoding {
        let tokens: Vec<SimpleToken> = encoding
            .get_ids()
            .iter()
            .zip(encoding.get_offsets())
            .map(|(&id, &(start, stop))| {
                let text: String = input.chars().skip(start).take(stop - start).collect();
                SimpleToken {
                    id,
                    text,
                    start,
                    stop,
                }
            })
            .collect();
        Ok(Json(TokenizeResponse(tokens)))
    } else {
        Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse {
                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
                error_type: "no fast tokenizer".to_string(),
            }),
        ))
    }
}

/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}

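/// Compute type advertised in the `x-compute-type` response header
/// (read from the `COMPUTE_TYPE` env var, defaulting to "gpu+optimized").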
#[derive(Clone, Debug)]
pub(crate) struct ComputeType(String);

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    model_info: HubModelInfo,
    shard_info: ShardInfo,
    compat_return_full_text: bool,
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    client: ShardedClient,
    tokenizer: Option<Tokenizer>,
    config: Option<Config>,
    validation_workers: usize,
    addr: SocketAddr,
    allow_origin: Option<AllowOrigin>,
    ngrok: bool,
    ngrok_authtoken: Option<String>,
    ngrok_edge: Option<String>,
    tokenizer_config: HubTokenizerConfig,
    messages_api_enabled: bool,
    grammar_support: bool,
) -> Result<(), axum::BoxError> {
    // OpenAPI documentation
    #[derive(OpenApi)]
    #[openapi(
    paths(
    health,
    get_model_info,
    compat_generate,
    generate,
    generate_stream,
    chat_completions,
    completions,
    tokenize,
    metrics,
    ),
    components(
    schemas(
    Info,
    CompatGenerateRequest,
    GenerateRequest,
    GrammarType,
    ChatRequest,
    Message,
    ChatCompletionComplete,
    ChatCompletionChoice,
    ChatCompletionDelta,
    ChatCompletionChunk,
    ChatCompletionLogprob,
    ChatCompletionLogprobs,
    ChatCompletionTopLogprob,
    ChatCompletion,
    CompletionRequest,
    CompletionComplete,
    CompletionCompleteChunk,
    GenerateParameters,
    PrefillToken,
    Token,
    GenerateResponse,
    TokenizeResponse,
    SimpleToken,
    BestOfSequence,
    Details,
    FinishReason,
    StreamResponse,
    StreamDetails,
    ErrorResponse,
    GrammarType,
    Usage,
    DeltaToolCall,
    ToolType,
    Tool,
    ToolCall,
    Function,
    FunctionDefinition,
    )
    ),
    tags(
    (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
    ),
    info(
    title = "Text Generation Inference",
    license(
    name = "Apache 2.0",
    url = "https://www.apache.org/licenses/LICENSE-2.0"
    )
    )
    )]
    struct ApiDoc;

    // Create state
    let validation = Validation::new(
        validation_workers,
        tokenizer,
        config,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_length,
        max_total_tokens,
        grammar_support,
    );
    let generation_health = Arc::new(AtomicBool::new(false));
    let health_ext = Health::new(client.clone(), generation_health.clone());
    let infer = Infer::new(
        client,
        validation,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        max_concurrent_requests,
        shard_info.requires_padding,
        shard_info.window_size,
        shard_info.speculate,
        generation_health,
        tokenizer_config,
    );

    // Duration buckets
    let duration_matcher = Matcher::Suffix(String::from("duration"));
    let n_duration_buckets = 35;
    let mut duration_buckets = Vec::with_capacity(n_duration_buckets);
    // Minimum duration in seconds
    let mut value = 0.0001;
    for _ in 0..n_duration_buckets {
        // geometric sequence
        value *= 1.5;
        duration_buckets.push(value);
    }
    // Input Length buckets
    let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
    let input_length_buckets: Vec<f64> = (0..100)
        .map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64)
        .collect();
    // Generated tokens buckets
    let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
    let generated_tokens_buckets: Vec<f64> = (0..100)
        .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
        .collect();
    // Max new tokens buckets
    let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens"));
    let max_new_tokens_buckets: Vec<f64> = (0..100)
        .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
        .collect();
    // Batch size buckets
    let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
    let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
    // Speculated tokens buckets
    let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
    let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();

    // Prometheus handler
    let builder = PrometheusBuilder::new()
        .set_buckets_for_metric(duration_matcher, &duration_buckets)
        .unwrap()
        .set_buckets_for_metric(input_length_matcher, &input_length_buckets)
        .unwrap()
        .set_buckets_for_metric(generated_tokens_matcher, &generated_tokens_buckets)
        .unwrap()
        .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
        .unwrap()
        .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
        .unwrap()
        .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
        .unwrap();
    let prom_handle = builder
        .install_recorder()
        .expect("failed to install metrics recorder");

    // CORS layer
    let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE])
        .allow_origin(allow_origin);

    // Endpoint info
    let info = Info {
        model_id: model_info.model_id,
        model_sha: model_info.sha,
        model_dtype: shard_info.dtype,
        model_device_type: shard_info.device_type,
        model_pipeline_tag: model_info.pipeline_tag,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        validation_workers,
        version: env!("CARGO_PKG_VERSION"),
        sha: option_env!("VERGEN_GIT_SHA"),
        docker_label: option_env!("DOCKER_LABEL"),
    };

    // Define VertextApiDoc conditionally only if the "google" feature is enabled
    let doc = {
        // avoid `mut` if possible
        #[cfg(feature = "google")]
        {
            use crate::VertexInstance;

            #[derive(OpenApi)]
            #[openapi(
                paths(vertex_compatibility),
                components(schemas(VertexInstance, VertexRequest, VertexResponse))
            )]
            struct VertextApiDoc;

            // limiting mutability to the smallest scope necessary
            let mut doc = ApiDoc::openapi();
            doc.merge(VertextApiDoc::openapi());
            doc
        }
        #[cfg(not(feature = "google"))]
        ApiDoc::openapi()
    };

    // Configure Swagger UI
    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);

    // Define base and health routes
    let base_routes = Router::new()
        .route("/", post(compat_generate))
        .route("/", get(health))
        .route("/info", get(get_model_info))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/completions", post(completions))
        .route("/vertex", post(vertex_compatibility))
        .route("/tokenize", post(tokenize))
        .route("/health", get(health))
        .route("/ping", get(health))
        .route("/metrics", get(metrics));

    // Conditional AWS Sagemaker route
    let aws_sagemaker_route = if messages_api_enabled {
        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' when the Messages API is enabled
    } else {
        Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
    };

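    // Compute environment label, surfaced to handlers via an extension
    // (falls back to "gpu+optimized" when `COMPUTE_TYPE` is unset)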
    let compute_type =
        ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));

    // Combine routes and layers
    let mut app = Router::new()
        .merge(swagger_ui)
        .merge(base_routes)
        .merge(aws_sagemaker_route);

    #[cfg(feature = "google")]
    {
        tracing::info!("Built with `google` feature");
        tracing::info!(
            "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
        );
        if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
            app = app.route(&env_predict_route, post(vertex_compatibility));
        }
        if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
            app = app.route(&env_health_route, get(health));
        }
    }

    // add layers after routes
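    // (axum layers only wrap the routes that were merged before the `.layer` calls)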
    app = app
        .layer(Extension(info))
        .layer(Extension(health_ext.clone()))
        .layer(Extension(compat_return_full_text))
        .layer(Extension(infer))
        .layer(Extension(compute_type))
        .layer(Extension(prom_handle.clone()))
        .layer(OtelAxumLayer::default())
        .layer(cors_layer);

    if ngrok {
        #[cfg(feature = "ngrok")]
        {
            use ngrok::config::TunnelBuilder;

            let _ = addr;

            let authtoken =
                ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");

            let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling");

            let tunnel = ngrok::Session::builder()
                .authtoken(authtoken)
                .connect()
                .await
                .unwrap()
                .labeled_tunnel()
                .label("edge", edge);

            let listener = tunnel.listen().await.unwrap();

            // Run prom metrics and health locally too
            tokio::spawn(
                axum::Server::bind(&addr)
                    .serve(
                        Router::new()
                            .route("/health", get(health))
                            .route("/metrics", get(metrics))
                            .layer(Extension(health_ext))
                            .layer(Extension(prom_handle))
                            .into_make_service(),
                    )
                    // Wait until all requests are finished to shut down
                    .with_graceful_shutdown(shutdown_signal()),
            );

            // Run server
            axum::Server::builder(listener)
                .serve(app.into_make_service())
                // Wait until all requests are finished to shut down
                .with_graceful_shutdown(shutdown_signal())
                .await?;
        }
        #[cfg(not(feature = "ngrok"))]
        {
            let _ngrok_authtoken = ngrok_authtoken;
            let _ngrok_domain = ngrok_domain;
            let _ngrok_username = ngrok_username;
            let _ngrok_password = ngrok_password;

            panic!("`text-generation-router` was compiled without the `ngrok` feature");
        }
    } else {
        // Run server
        axum::Server::bind(&addr)
            .serve(app.into_make_service())
            // Wait until all requests are finished to shut down
            .with_graceful_shutdown(shutdown_signal())
            .await?;
    }
    Ok(())
}

/// Shutdown signal handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    opentelemetry::global::shutdown_tracer_provider();
}

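/// Map the protobuf `FinishReason` reported by the shards onto the HTTP API enum.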
impl From<i32> for FinishReason {
    fn from(finish_reason: i32) -> Self {
        let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
        match finish_reason {
            text_generation_client::FinishReason::Length => FinishReason::Length,
            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}

/// Convert to Axum supported formats
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
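            // map onto the statuses documented in the OpenAPI error responses (424 / 429 / 422 / 500)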
            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
            InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
        };

        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            }),
        )
    }
}

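/// Render an `InferError` as a Server-Sent Event so streaming endpoints can report errors in-band.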
impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            })
            .unwrap()
    }
}