use crate::config::Config;
/// HTTP Server logic
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse, ToolGrammar};
use crate::validation::ValidationError;
use crate::{
    BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
    GenerateResponse, GrammarType, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Usage,
    Validation,
};
use crate::{
    ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
    ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
    ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
    CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
};
use crate::{FunctionDefinition, ToolCall, ToolType};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::stream::FuturesUnordered;
use futures::stream::StreamExt;
use futures::Stream;
use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient};
use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

/// Generate tokens if `stream == false` or a stream of tokens if `stream == true`
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/",
request_body = CompatGenerateRequest,
responses(
(status = 200, description = "Generated Text",
content(
("application/json" = GenerateResponse),
("text/event-stream" = StreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn compat_generate(
    Extension(default_return_full_text): Extension<bool>,
    infer: Extension<Infer>,
    compute_type: Extension<ComputeType>,
    Json(mut req): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
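    // Illustrative request body for this route (shape of `CompatGenerateRequest`; the
    // prompt and parameter values are placeholders, not defaults):
    //   {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}, "stream": false}
    // With `"stream": true` the same route answers with Server-Sent Events instead of JSON.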
    // default return_full_text given the pipeline_tag
    if req.parameters.return_full_text.is_none() {
        req.parameters.return_full_text = Some(default_return_full_text)
    }

    // switch on stream
    if req.stream {
        Ok(generate_stream(infer, compute_type, Json(req.into()))
            .await
            .into_response())
    } else {
        let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
        // wrap generation inside a Vec to match api-inference
        Ok((headers, Json(vec![generation])).into_response())
    }
}

/// Text Generation Inference endpoint info
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/info",
responses((status = 200, description = "Served model info", body = Info))
)]
#[instrument]
async fn get_model_info(info: Extension<Info>) -> Json<Info> {
    Json(info.0)
}

#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/health",
responses(
(status = 200, description = "Everything is working fine"),
(status = 503, description = "Text generation inference is down", body = ErrorResponse,
example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)
)]
#[instrument(skip(health))]
/// Health check method
async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    match health.check().await {
        true => Ok(()),
        false => Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse {
                error: "unhealthy".to_string(),
                error_type: "healthcheck".to_string(),
            }),
        )),
    }
}

/// Generate tokens
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = GenerateResponse),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(
skip_all,
fields(
parameters = ? req.parameters,
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn generate(
    infer: Extension<Infer>,
    Extension(ComputeType(compute_type)): Extension<ComputeType>,
    Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();
    metrics::increment_counter!("tgi_request_count");

    // Do not log ultra long inputs, like image payloads.
    tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);

    let compute_characters = req.inputs.chars().count();
    let mut add_prompt = None;
    if req.parameters.return_full_text.unwrap_or(false) {
        add_prompt = Some(req.inputs.clone());
    }

    let details: bool = req.parameters.details || req.parameters.decoder_input_details;

    // Inference
    let (response, best_of_responses) = match req.parameters.best_of {
        Some(best_of) if best_of > 1 => {
            let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;
            (response, Some(best_of_responses))
        }
        _ => (infer.generate(req).await?, None),
    };

    // Token details
    let input_length = response._input_length;
    let details = match details {
        true => {
            // convert best_of_responses
            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {
                responses
                    .into_iter()
                    .map(|response: InferResponse| {
                        // Add prompt if return_full_text
                        let mut output_text = response.generated_text.text;
                        if let Some(prompt) = &add_prompt {
                            output_text = prompt.clone() + &output_text;
                        }

                        BestOfSequence {
                            generated_text: output_text,
                            finish_reason: FinishReason::from(
                                response.generated_text.finish_reason,
                            ),
                            generated_tokens: response.generated_text.generated_tokens,
                            prefill: response.prefill,
                            tokens: response.tokens,
                            top_tokens: response.top_tokens,
                            seed: response.generated_text.seed,
                        }
                    })
                    .collect()
            });

            Some(Details {
                finish_reason: FinishReason::from(response.generated_text.finish_reason),
                generated_tokens: response.generated_text.generated_tokens,
                prefill: response.prefill,
                tokens: response.tokens,
                seed: response.generated_text.seed,
                best_of_sequences,
                top_tokens: response.top_tokens,
            })
        }
        false => None,
    };

    // Timings
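    // Request timeline used below: `start_time` (request received) -> `response.queued`
    // (validation finished, request enqueued) -> `response.start` (inference started) ->
    // `Instant::now()` (last token received); each duration is the gap between two of
    // these points, and `time_per_token` averages inference time over the generated tokens.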
    let total_time = start_time.elapsed();
    let validation_time = response.queued - start_time;
    let queue_time = response.start - response.queued;
    let inference_time = Instant::now() - response.start;
    let time_per_token = inference_time / response.generated_text.generated_tokens;

    // Tracing metadata
    span.record("total_time", format!("{total_time:?}"));
    span.record("validation_time", format!("{validation_time:?}"));
    span.record("queue_time", format!("{queue_time:?}"));
    span.record("inference_time", format!("{inference_time:?}"));
    span.record("time_per_token", format!("{time_per_token:?}"));
    span.record("seed", format!("{:?}", response.generated_text.seed));

    // Headers
    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", compute_type.parse().unwrap());
    headers.insert(
        "x-compute-time",
        total_time.as_secs_f64().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    headers.insert(
        "x-total-time",
        total_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-validation-time",
        validation_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-queue-time",
        queue_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-inference-time",
        inference_time.as_millis().to_string().parse().unwrap(),
    );
    headers.insert(
        "x-time-per-token",
        time_per_token.as_millis().to_string().parse().unwrap(),
    );
    headers.insert("x-prompt-tokens", input_length.into());
    headers.insert(
        "x-generated-tokens",
        response.generated_text.generated_tokens.into(),
    );

    // Metrics
    metrics::increment_counter!("tgi_request_success");
    metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
    metrics::histogram!(
        "tgi_request_validation_duration",
        validation_time.as_secs_f64()
    );
    metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
    metrics::histogram!(
        "tgi_request_inference_duration",
        inference_time.as_secs_f64()
    );
    metrics::histogram!(
        "tgi_request_mean_time_per_token_duration",
        time_per_token.as_secs_f64()
    );
    metrics::histogram!(
        "tgi_request_generated_tokens",
        response.generated_text.generated_tokens as f64
    );

    // Send response
    let mut output_text = response.generated_text.text;
    if let Some(prompt) = add_prompt {
        output_text = prompt + &output_text;
    }

    tracing::debug!("Output: {}", output_text);
    tracing::info!("Success");

    let response = GenerateResponse {
        generated_text: output_text,
        details,
    };
    Ok((headers, Json(response)))
}

/// Generate a stream of tokens using Server-Sent Events
#[utoipa::path(
post,
tag = "Text Generation Inference",
path = "/generate_stream",
request_body = GenerateRequest,
responses(
(status = 200, description = "Generated Text", body = StreamResponse,
content_type = "text/event-stream"),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"}),
content_type = "text/event-stream"),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"}),
content_type = "text/event-stream"),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"}),
content_type = "text/event-stream"),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"}),
content_type = "text/event-stream"),
)
)]
#[instrument(
skip_all,
fields(
parameters = ? req.parameters,
total_time,
validation_time,
queue_time,
inference_time,
time_per_token,
seed,
)
)]
async fn generate_stream(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<GenerateRequest>,
) -> (
    HeaderMap,
    Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
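    // The shared streaming core (`generate_stream_internal`) takes a callback that turns
    // each `StreamResponse` into an SSE `Event`. This endpoint serializes the token
    // payload as-is; the OpenAI-style endpoints below wrap it into their own chunk types.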
    let on_message_callback = |stream_token: StreamResponse| {
        let event = Event::default();
        event.json_data(stream_token).unwrap()
    };
    let (headers, response_stream) =
        generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
    let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
    (headers, sse)
}

async fn generate_stream_internal(
    infer: Infer,
    ComputeType(compute_type): ComputeType,
    Json(req): Json<GenerateRequest>,
    on_message_callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
    let span = tracing::Span::current();
    let start_time = Instant::now();
    metrics::increment_counter!("tgi_request_count");

    tracing::debug!("Input: {}", req.inputs);

    let compute_characters = req.inputs.chars().count();

    let mut headers = HeaderMap::new();
    headers.insert("x-compute-type", compute_type.parse().unwrap());
    headers.insert(
        "x-compute-characters",
        compute_characters.to_string().parse().unwrap(),
    );
    headers.insert("X-Accel-Buffering", "no".parse().unwrap());

    let stream = async_stream::stream! {
        // Inference
        let mut end_reached = false;
        let mut error = false;

        let mut add_prompt = None;
        if req.parameters.return_full_text.unwrap_or(false) {
            add_prompt = Some(req.inputs.clone());
        }
        let details = req.parameters.details;

        let best_of = req.parameters.best_of.unwrap_or(1);
        if best_of != 1 {
            let err = InferError::from(ValidationError::BestOfStream);
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        } else if req.parameters.decoder_input_details {
            let err = InferError::from(ValidationError::PrefillDetailsStream);
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            yield Ok(Event::from(err));
        } else {
            match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
                // Keep permit as long as generate_stream lives
                Ok((_permit, _input_length, mut response_stream)) => {
                    let mut index = 0;
                    // Server-Sent Event stream
                    while let Some(response) = response_stream.next().await {
                        index += 1;
                        match response {
                            Ok(response) => {
                                match response {
                                    // Prefill is ignored
                                    InferStreamResponse::Prefill(_) => {}
                                    // Yield event for every new token
                                    InferStreamResponse::Intermediate{
                                        token,
                                        top_tokens,
                                    } => {
                                        tracing::debug!(parent: &span, "Token: {:?}", token);

                                        // StreamResponse
                                        let stream_token = StreamResponse {
                                            index,
                                            token,
                                            top_tokens,
                                            generated_text: None,
                                            details: None,
                                        };
                                        let event = on_message_callback(stream_token);
                                        yield Ok(event);
                                    }
                                    // Yield event for last token and compute timings
                                    InferStreamResponse::End {
                                        token,
                                        generated_text,
                                        start,
                                        queued,
                                        top_tokens,
                                    } => {
                                        // Token details
                                        let details = match details {
                                            true => Some(StreamDetails {
                                                finish_reason: FinishReason::from(generated_text.finish_reason),
                                                generated_tokens: generated_text.generated_tokens,
                                                seed: generated_text.seed,
                                            }),
                                            false => None,
                                        };

                                        // Timings
                                        let total_time = start_time.elapsed();
                                        let validation_time = queued - start_time;
                                        let queue_time = start - queued;
                                        let inference_time = Instant::now() - start;
                                        let time_per_token = inference_time / generated_text.generated_tokens;

                                        // Tracing metadata
                                        span.record("total_time", format!("{total_time:?}"));
                                        span.record("validation_time", format!("{validation_time:?}"));
                                        span.record("queue_time", format!("{queue_time:?}"));
                                        span.record("inference_time", format!("{inference_time:?}"));
                                        span.record("time_per_token", format!("{time_per_token:?}"));
                                        span.record("seed", format!("{:?}", generated_text.seed));

                                        // Metrics
                                        metrics::increment_counter!("tgi_request_success");
                                        metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64());
                                        metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64());
                                        metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);

                                        // StreamResponse
                                        end_reached = true;

                                        let mut output_text = generated_text.text;
                                        if let Some(prompt) = add_prompt {
                                            output_text = prompt + &output_text;
                                        }

                                        tracing::debug!(parent: &span, "Output: {}", output_text);
                                        tracing::info!(parent: &span, "Success");

                                        let stream_token = StreamResponse {
                                            index,
                                            token,
                                            top_tokens,
                                            generated_text: Some(output_text),
                                            details
                                        };

                                        let event = on_message_callback(stream_token);
                                        yield Ok(event);
                                        break;
                                    }
                                }
                            }
                            // yield error
                            Err(err) => {
                                error = true;
                                yield Ok(Event::from(err));
                                break;
                            }
                        }
                    }
                },
                // yield error
                Err(err) => {
                    error = true;
                    yield Ok(Event::from(err));
                }
            }
            // Check if generation reached the end
            // Skip if we already sent an error
            if !end_reached && !error {
                let err = InferError::IncompleteGeneration;
                metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
                tracing::error!("{err}");
                yield Ok(Event::from(err));
            }
        }
    };

    (headers, stream)
}

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v1/completions",
    request_body = CompletionRequest,
    responses(
    (status = 200, description = "Generated Chat Completion",
    content(
    ("application/json" = Completion),
    ("text/event-stream" = CompletionCompleteChunk),
    )),
    (status = 424, description = "Generation Error", body = ErrorResponse,
    example = json ! ({"error": "Request failed during generation"})),
    (status = 429, description = "Model is overloaded", body = ErrorResponse,
    example = json ! ({"error": "Model is overloaded"})),
    (status = 422, description = "Input validation error", body = ErrorResponse,
    example = json ! ({"error": "Input validation error"})),
    (status = 500, description = "Incomplete generation", body = ErrorResponse,
    example = json ! ({"error": "Incomplete generation"})),
    )
    )]
#[instrument(
    skip_all,
    fields(
    // parameters = ? req.parameters,
    total_time,
    validation_time,
    queue_time,
    inference_time,
    time_per_token,
    seed,
    )
    )]
async fn completions(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Extension(info): Extension<Info>,
    Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    let stream = req.stream;
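    // OpenAI's `max_tokens` is optional; fall back to 100 new tokens when it is absent
    // (the chat endpoint below applies the same default).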
    let max_new_tokens = req.max_tokens.or(Some(100));
    let seed = req.seed;

    // if suffix is present throw an error
    if req.suffix.is_some() {
        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Suffix is not supported and can be achieved by preprocessing the prompt."
                    .to_string(),
                error_type: "suffix not supported".to_string(),
            }),
        ));
    }

    // build the request passing some parameters
    let generate_request = GenerateRequest {
        inputs: req.prompt.to_string(),
        parameters: GenerateParameters {
            best_of: None,
            temperature: req.temperature,
            repetition_penalty: req.repetition_penalty,
            frequency_penalty: req.frequency_penalty,
            top_k: None,
            top_p: req.top_p,
            typical_p: None,
            do_sample: true,
            max_new_tokens,
            return_full_text: None,
            stop: Vec::new(),
            truncate: None,
            watermark: false,
            details: true,
            decoder_input_details: !stream,
            seed,
            top_n_tokens: None,
            grammar: None,
        },
    };

    if stream {
        let on_message_callback = move |stream_token: StreamResponse| {
            let event = Event::default();

            let current_time = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                .as_secs();

            event
                .json_data(CompletionCompleteChunk {
                    id: "".to_string(),
                    object: "text_completion".to_string(),
                    created: current_time,

                    choices: vec![CompletionComplete {
                        finish_reason: "".to_string(),
                        index: 0,
                        logprobs: None,
                        text: stream_token.token.text,
                    }],

                    model: info.model_id.clone(),
                    system_fingerprint: format!(
                        "{}-{}",
                        info.version,
                        info.docker_label.unwrap_or("native")
                    ),
                })
                .map_or_else(
                    |e| {
                        println!("Failed to serialize CompletionCompleteChunk: {:?}", e);
                        Event::default()
                    },
                    |data| data,
                )
        };

        let (headers, response_stream) = generate_stream_internal(
            infer,
            compute_type,
            Json(generate_request),
            on_message_callback,
        )
        .await;

        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
        Ok((headers, sse).into_response())
    } else {
        let (headers, Json(generation)) = generate(
            Extension(infer),
            Extension(compute_type),
            Json(generate_request),
        )
        .await?;

        let current_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
            .as_secs();

        let details = generation.details.ok_or((
            // this should never happen but handle if details are missing unexpectedly
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: "No details in generation".to_string(),
                error_type: "no details".to_string(),
            }),
        ))?;

        let response = Completion {
            id: "".to_string(),
            object: "text_completion".to_string(),
            created: current_time,
            model: info.model_id.clone(),
            system_fingerprint: format!(
                "{}-{}",
                info.version,
                info.docker_label.unwrap_or("native")
            ),
            choices: vec![CompletionComplete {
                finish_reason: details.finish_reason.to_string(),
                index: 0,
                logprobs: None,
                text: generation.generated_text,
            }],
            usage: Usage {
                prompt_tokens: details.prefill.len() as u32,
                completion_tokens: details.generated_tokens,
                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
            },
        };

        Ok((headers, Json(response)).into_response())
    }
}

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/v1/chat/completions",
    request_body = ChatRequest,
    responses(
    (status = 200, description = "Generated Chat Completion",
    content(
    ("application/json" = ChatCompletion),
    ("text/event-stream" = ChatCompletionChunk),
    )),
    (status = 424, description = "Generation Error", body = ErrorResponse,
    example = json ! ({"error": "Request failed during generation"})),
    (status = 429, description = "Model is overloaded", body = ErrorResponse,
    example = json ! ({"error": "Model is overloaded"})),
    (status = 422, description = "Input validation error", body = ErrorResponse,
    example = json ! ({"error": "Input validation error"})),
    (status = 500, description = "Incomplete generation", body = ErrorResponse,
    example = json ! ({"error": "Incomplete generation"})),
    )
    )]
#[instrument(
    skip_all,
    fields(
    // parameters = ? req.parameters,
    total_time,
    validation_time,
    queue_time,
    inference_time,
    time_per_token,
    seed,
    )
    )]
async fn chat_completions(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Extension(info): Extension<Info>,
    Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    let ChatRequest {
        logprobs,
        max_tokens,
        messages,
        presence_penalty,
        seed,
        stop,
        stream,
        tools,
        tool_choice,
        tool_prompt,
        ..
    } = req;

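    // OpenAI's `presence_penalty` is documented as a value in [-2.0, 2.0]; TGI only has a
    // multiplicative `repetition_penalty`, so the value is shifted by +2.0 as a rough mapping.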
    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
    let max_new_tokens = max_tokens.or(Some(100));
    let logprobs = logprobs.unwrap_or(false);
    let tool_prompt = tool_prompt.unwrap_or_default();
    let stop = stop.unwrap_or_default();

    // extract tool grammar if present
    let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
        Ok(grammar) => grammar,
        Err(err) => {
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            return Err((
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(ErrorResponse {
                    error: err.to_string(),
                    error_type: err.error_type().to_string(),
                }),
            ));
        }
    };

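    // When tools are supplied, the tool schema is converted into a JSON grammar that
    // constrains generation, and the accompanying tool prompt is passed to
    // `apply_chat_template` alongside it.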
    let grammar_with_prompt = tool_grammar
        .as_ref()
        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));

    let typed_grammar = grammar_with_prompt
        .as_ref()
        .map(|(grammar, _)| grammar.clone());

    // apply chat template to flatten the request into a single input
    let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
        Ok(inputs) => inputs,
        Err(err) => {
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            return Err((
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(ErrorResponse {
                    error: err.to_string(),
                    error_type: err.error_type().to_string(),
                }),
            ));
        }
    };

    // build the request passing some parameters
    let generate_request = GenerateRequest {
        inputs: inputs.to_string(),
        parameters: GenerateParameters {
            best_of: None,
            temperature: req.temperature,
            repetition_penalty,
            frequency_penalty: req.frequency_penalty,
            top_k: None,
            top_p: req.top_p,
            typical_p: None,
            do_sample: true,
            max_new_tokens,
            return_full_text: None,
            stop,
            truncate: None,
            watermark: false,
            details: true,
            decoder_input_details: !stream,
            seed,
            top_n_tokens: req.top_logprobs,
            grammar: typed_grammar,
        },
    };

    // static values that will be returned in all cases
    let model_id = info.model_id.clone();
    let system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));

    // switch on stream
    if stream {
        // pass this callback to the stream generation and build the required event structure
        let on_message_callback = move |stream_token: StreamResponse| {
            let event = Event::default();

            let current_time = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                .as_secs();

            let logprobs = logprobs.then(|| {
                ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
            });

            // replace the content with the tool calls if grammar is present
            let (content, tool_calls) = if tool_grammar.is_some() {
                (None, Some(vec![stream_token.token.text]))
            } else {
                (Some(stream_token.token.text), None)
            };

            event
                .json_data(ChatCompletionChunk::new(
                    model_id.clone(),
                    system_fingerprint.clone(),
                    content,
                    tool_calls,
                    current_time,
                    logprobs,
                    stream_token.details.map(|d| d.finish_reason.to_string()),
                ))
                .map_or_else(
                    |e| {
                        println!("Failed to serialize ChatCompletionChunk: {:?}", e);
                        Event::default()
                    },
                    |data| data,
                )
        };

        let (headers, response_stream) = generate_stream_internal(
            infer,
            compute_type,
            Json(generate_request),
            on_message_callback,
        )
        .await;
        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
        Ok((headers, sse).into_response())
    } else {
        let (headers, Json(generation)) = generate(
            Extension(infer),
            Extension(compute_type),
            Json(generate_request),
        )
        .await?;

        let current_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
            .as_secs();

        let (tool_calls, output) = if tool_grammar.is_some() {
            // gen_text should be valid json
            let gen_text_value: Value =
                serde_json::from_str(&generation.generated_text).map_err(|e| {
                    (
                        StatusCode::UNPROCESSABLE_ENTITY,
                        Json(ErrorResponse {
                            error: e.to_string(),
                            error_type: "Input validation error".to_string(),
                        }),
                    )
                })?;
            let tool_calls = vec![ToolCall {
                id: 0,
                r#type: "function".to_string(),
                function: FunctionDefinition {
                    description: None,
                    name: gen_text_value
                        .get("function")
                        .and_then(|f| f.get("_name"))
                        .and_then(|name| name.as_str())
                        .unwrap_or("default_function_name")
                        .to_string(),
                    // Serialize the JSON object obtained from "function" to an escaped JSON string
                    arguments: gen_text_value
                        .get("function")
                        .map(|f| {
                            let mut f_cloned = f.clone();
                            if let Value::Object(ref mut props) = f_cloned {
                                props.remove("_name");
                            }
                            f_cloned
                        })
                        .unwrap_or_default(),
                },
            }];
            (Some(tool_calls), None)
        } else {
            (None, Some(generation.generated_text))
        };
        // build the complete response object with the full text
        let response = ChatCompletion::new(
            model_id,
            system_fingerprint,
            output,
            current_time,
            generation.details.unwrap(),
            logprobs,
            tool_calls,
        );

        // return the full response object
        Ok((headers, Json(response)).into_response())
    }
}

/// Generate tokens from Vertex request
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/vertex",
    request_body = VertexRequest,
    responses(
    (status = 200, description = "Generated Text", body = VertexResponse),
    (status = 424, description = "Generation Error", body = ErrorResponse,
    example = json ! ({"error": "Request failed during generation"})),
    (status = 429, description = "Model is overloaded", body = ErrorResponse,
    example = json ! ({"error": "Model is overloaded"})),
    (status = 422, description = "Input validation error", body = ErrorResponse,
    example = json ! ({"error": "Input validation error"})),
    (status = 500, description = "Incomplete generation", body = ErrorResponse,
    example = json ! ({"error": "Incomplete generation"})),
    )
    )]
#[instrument(
    skip_all,
    fields(
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn vertex_compatibility(
    Extension(infer): Extension<Infer>,
    Extension(compute_type): Extension<ComputeType>,
    Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    metrics::increment_counter!("tgi_request_count");

    // check that there's at least one instance
    if req.instances.is_empty() {
        return Err((
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: "Input validation error".to_string(),
                error_type: "Input validation error".to_string(),
            }),
        ));
    }

    // Process all instances
    let predictions = req
        .instances
        .iter()
        .map(|instance| {
            let generate_request = GenerateRequest {
                inputs: instance.inputs.clone(),
                parameters: GenerateParameters {
                    do_sample: true,
                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
                    seed: instance.parameters.as_ref().and_then(|p| p.seed),
                    details: true,
                    decoder_input_details: true,
                    ..Default::default()
                },
            };

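            // Each instance runs as its own internal `generate` call; the calls are
            // driven concurrently through `FuturesUnordered` and any failure is mapped
            // to a generic "Incomplete generation" error below.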
            async {
                generate(
                    Extension(infer.clone()),
                    Extension(compute_type.clone()),
                    Json(generate_request),
                )
                .await
                .map(|(_, Json(generation))| generation.generated_text)
                .map_err(|_| {
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(ErrorResponse {
                            error: "Incomplete generation".into(),
                            error_type: "Incomplete generation".into(),
                        }),
                    )
                })
            }
        })
        .collect::<FuturesUnordered<_>>()
        .try_collect::<Vec<_>>()
        .await?;

    let response = VertexResponse { predictions };
    Ok((HeaderMap::new(), Json(response)).into_response())
}

/// Tokenize inputs
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/tokenize",
    request_body = GenerateRequest,
    responses(
    (status = 200, description = "Tokenized ids", body = TokenizeResponse),
    (status = 404, description = "No tokenizer found", body = ErrorResponse,
    example = json ! ({"error": "No fast tokenizer available"})),
    )
    )]
#[instrument(skip_all)]
async fn tokenize(
    Extension(infer): Extension<Infer>,
    Json(req): Json<GenerateRequest>,
) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {
    let input = req.inputs.clone();
    let encoding = infer.tokenize(req).await?;
    if let Some(encoding) = encoding {
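        // Map each token id back onto the original input using the encoding's character
        // offsets, so the response carries both the id and the exact text slice it covers.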
        let tokens: Vec<SimpleToken> = encoding
            .get_ids()
            .iter()
            .zip(encoding.get_offsets())
            .map(|(&id, &(start, stop))| {
                let text: String = input.chars().skip(start).take(stop - start).collect();
                SimpleToken {
                    id,
                    text,
                    start,
                    stop,
                }
            })
            .collect();
        Ok(Json(TokenizeResponse(tokens)))
    } else {
        Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse {
                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
                error_type: "no fast tokenizer".to_string(),
            }),
        ))
    }
}

/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}

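/// Compute backend label echoed back to clients in the `x-compute-type` header;
/// populated from the `COMPUTE_TYPE` environment variable (default: `gpu+optimized`).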
#[derive(Clone, Debug)]
pub(crate) struct ComputeType(String);

/// Serving method
#[allow(clippy::too_many_arguments)]
pub async fn run(
    model_info: HubModelInfo,
    shard_info: ShardInfo,
    compat_return_full_text: bool,
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_top_n_tokens: u32,
    max_input_length: usize,
    max_total_tokens: usize,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    client: ShardedClient,
    tokenizer: Option<Tokenizer>,
    config: Option<Config>,
    validation_workers: usize,
    addr: SocketAddr,
    allow_origin: Option<AllowOrigin>,
    ngrok: bool,
    ngrok_authtoken: Option<String>,
    ngrok_edge: Option<String>,
    tokenizer_config: HubTokenizerConfig,
    messages_api_enabled: bool,
    grammar_support: bool,
) -> Result<(), axum::BoxError> {
    // OpenAPI documentation
    #[derive(OpenApi)]
    #[openapi(
    paths(
    health,
    get_model_info,
    compat_generate,
    generate,
    generate_stream,
    chat_completions,
    completions,
    tokenize,
    metrics,
    ),
    components(
    schemas(
    Info,
    CompatGenerateRequest,
    GenerateRequest,
    GrammarType,
    ChatRequest,
    Message,
    ChatCompletionComplete,
    ChatCompletionChoice,
    ChatCompletionDelta,
    ChatCompletionChunk,
    ChatCompletionLogprob,
    ChatCompletionLogprobs,
    ChatCompletionTopLogprob,
    ChatCompletion,
    CompletionRequest,
    CompletionComplete,
    CompletionCompleteChunk,
    GenerateParameters,
    PrefillToken,
    Token,
    GenerateResponse,
    TokenizeResponse,
    SimpleToken,
    BestOfSequence,
    Details,
    FinishReason,
    StreamResponse,
    StreamDetails,
    ErrorResponse,
    GrammarType,
    Usage,
    DeltaToolCall,
    ToolType,
    Tool,
    ToolCall,
    Function,
    FunctionDefinition,
    )
    ),
    tags(
    (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
    ),
    info(
    title = "Text Generation Inference",
    license(
    name = "Apache 2.0",
    url = "https://www.apache.org/licenses/LICENSE-2.0"
    )
    )
    )]
    struct ApiDoc;

    // Create state
    let validation = Validation::new(
        validation_workers,
        tokenizer,
        config,
        max_best_of,
        max_stop_sequences,
        max_top_n_tokens,
        max_input_length,
        max_total_tokens,
        grammar_support,
    );
    let generation_health = Arc::new(AtomicBool::new(false));
    let health_ext = Health::new(client.clone(), generation_health.clone());
    let infer = Infer::new(
        client,
        validation,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        max_concurrent_requests,
        shard_info.requires_padding,
        shard_info.window_size,
        shard_info.speculate,
        generation_health,
        tokenizer_config,
    );

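    // Histogram buckets for the Prometheus recorder installed below. Duration buckets
    // form a geometric series (35 steps of x1.5 starting from 0.1 ms, i.e. roughly
    // 0.15 ms up to ~2.4 minutes); input-length and token-count buckets are 100 linear
    // steps up to their configured maxima, and batch-size buckets cover 1..=1024.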
    // Duration buckets
    let duration_matcher = Matcher::Suffix(String::from("duration"));
    let n_duration_buckets = 35;
    let mut duration_buckets = Vec::with_capacity(n_duration_buckets);
    // Minimum duration in seconds
    let mut value = 0.0001;
    for _ in 0..n_duration_buckets {
        // geometric sequence
        value *= 1.5;
        duration_buckets.push(value);
    }
    // Input Length buckets
    let input_length_matcher = Matcher::Full(String::from("tgi_request_input_length"));
    let input_length_buckets: Vec<f64> = (0..100)
        .map(|x| (max_input_length as f64 / 100.0) * (x + 1) as f64)
        .collect();
    // Generated tokens buckets
    let generated_tokens_matcher = Matcher::Full(String::from("tgi_request_generated_tokens"));
    let generated_tokens_buckets: Vec<f64> = (0..100)
        .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
        .collect();
    // Input Length buckets
    let max_new_tokens_matcher = Matcher::Full(String::from("tgi_request_max_new_tokens"));
    let max_new_tokens_buckets: Vec<f64> = (0..100)
        .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)
        .collect();
    // Batch size buckets
    let batch_size_matcher = Matcher::Full(String::from("tgi_batch_next_size"));
    let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();
    // Speculated tokens buckets
    let skipped_matcher = Matcher::Full(String::from("tgi_request_skipped_tokens"));
    let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();

    // Prometheus handler
    let builder = PrometheusBuilder::new()
        .set_buckets_for_metric(duration_matcher, &duration_buckets)
        .unwrap()
        .set_buckets_for_metric(input_length_matcher, &input_length_buckets)
        .unwrap()
        .set_buckets_for_metric(generated_tokens_matcher, &generated_tokens_buckets)
        .unwrap()
        .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)
        .unwrap()
        .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)
        .unwrap()
        .set_buckets_for_metric(skipped_matcher, &skipped_buckets)
        .unwrap();
    let prom_handle = builder
        .install_recorder()
        .expect("failed to install metrics recorder");

    // CORS layer
    let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
    let cors_layer = CorsLayer::new()
        .allow_methods([Method::GET, Method::POST])
        .allow_headers([http::header::CONTENT_TYPE])
        .allow_origin(allow_origin);

    // Endpoint info
    let info = Info {
        model_id: model_info.model_id,
        model_sha: model_info.sha,
        model_dtype: shard_info.dtype,
        model_device_type: shard_info.device_type,
        model_pipeline_tag: model_info.pipeline_tag,
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        waiting_served_ratio,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        validation_workers,
        version: env!("CARGO_PKG_VERSION"),
        sha: option_env!("VERGEN_GIT_SHA"),
        docker_label: option_env!("DOCKER_LABEL"),
    };

    // Define VertextApiDoc conditionally only if the "google" feature is enabled
    let doc = {
        // avoid `mut` if possible
        #[cfg(feature = "google")]
        {
            use crate::VertexInstance;

            #[derive(OpenApi)]
            #[openapi(
                paths(vertex_compatibility),
                components(schemas(VertexInstance, VertexRequest, VertexResponse))
            )]
            struct VertexApiDoc;

            // limiting mutability to the smallest scope necessary
            let mut doc = ApiDoc::openapi();
            doc.merge(VertexApiDoc::openapi());
            doc
        }
        #[cfg(not(feature = "google"))]
        ApiDoc::openapi()
    };

    // Configure Swagger UI
    let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);

    // Define base and health routes
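    // `/` dispatches on the `stream` flag via `compat_generate`; `/health` and `/ping` share the same handler.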
    let base_routes = Router::new()
        .route("/", post(compat_generate))
        .route("/", get(health))
        .route("/info", get(get_model_info))
        .route("/generate", post(generate))
        .route("/generate_stream", post(generate_stream))
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/completions", post(completions))
        .route("/vertex", post(vertex_compatibility))
        .route("/tokenize", post(tokenize))
        .route("/health", get(health))
        .route("/ping", get(health))
        .route("/metrics", get(metrics));

    // Conditional AWS Sagemaker route
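    // `/invocations` serves the chat completions handler when the Messages API is enabled, and the legacy compat endpoint otherwise.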
    let aws_sagemaker_route = if messages_api_enabled {
        Router::new().route("/invocations", post(chat_completions)) // Use 'chat_completions' for OAI_ENABLED
    } else {
        Router::new().route("/invocations", post(compat_generate)) // Use 'compat_generate' otherwise
    };

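    // Advertised compute type, read from the environment with a generic GPU default.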
    let compute_type =
        ComputeType(std::env::var("COMPUTE_TYPE").unwrap_or("gpu+optimized".to_string()));

    // Combine routes and layers
    let mut app = Router::new()
        .merge(swagger_ui)
        .merge(base_routes)
        .merge(aws_sagemaker_route);

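    // With the `google` feature, Vertex AI can override the predict and health routes via environment variables.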
    #[cfg(feature = "google")]
    {
        tracing::info!("Built with `google` feature");
        tracing::info!(
            "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
        );
        if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
            app = app.route(&env_predict_route, post(vertex_compatibility));
        }
        if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
            app = app.route(&env_health_route, get(health));
        }
    }

    // add layers after routes
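    // Shared state (info, health, inference handle, compute type, metrics) is injected as extensions; OpenTelemetry tracing and CORS wrap the whole router.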
    app = app
        .layer(Extension(info))
        .layer(Extension(health_ext.clone()))
        .layer(Extension(compat_return_full_text))
        .layer(Extension(infer))
        .layer(Extension(compute_type))
        .layer(Extension(prom_handle.clone()))
        .layer(OtelAxumLayer::default())
        .layer(cors_layer);

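    // With ngrok enabled, public traffic is served through an ngrok labeled tunnel while `/health` and `/metrics` stay bound to the local address.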
    if ngrok {
        #[cfg(feature = "ngrok")]
        {
            use ngrok::config::TunnelBuilder;

            let _ = addr;

            let authtoken =
                ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");

            let edge = ngrok_edge.expect("`ngrok-edge` must be set when using ngrok tunneling");

            let tunnel = ngrok::Session::builder()
                .authtoken(authtoken)
                .connect()
                .await
                .unwrap()
                .labeled_tunnel()
                .label("edge", edge);

            let listener = tunnel.listen().await.unwrap();

            // Run prom metrics and health locally too
            tokio::spawn(
                axum::Server::bind(&addr)
                    .serve(
                        Router::new()
                            .route("/health", get(health))
                            .route("/metrics", get(metrics))
                            .layer(Extension(health_ext))
                            .layer(Extension(prom_handle))
                            .into_make_service(),
                    )
                    // Wait until all requests are finished to shut down
                    .with_graceful_shutdown(shutdown_signal()),
            );

            // Run server
            axum::Server::builder(listener)
                .serve(app.into_make_service())
                // Wait until all requests are finished to shut down
                .with_graceful_shutdown(shutdown_signal())
                .await?;
        }
        #[cfg(not(feature = "ngrok"))]
        {
            let _ngrok_authtoken = ngrok_authtoken;
            let _ngrok_edge = ngrok_edge;

            panic!("`text-generation-router` was compiled without the `ngrok` feature");
        }
    } else {
        // Run server
        axum::Server::bind(&addr)
            .serve(app.into_make_service())
            // Wait until all requests are finished to shut down
            .with_graceful_shutdown(shutdown_signal())
            .await?;
    }
    Ok(())
}

/// Shutdown signal handler
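/// Resolves on Ctrl+C or, on Unix, SIGTERM, then shuts down the OpenTelemetry tracer provider.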
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }

    tracing::info!("signal received, starting graceful shutdown");
    opentelemetry::global::shutdown_tracer_provider();
}

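// Map the shard client's protobuf finish reason onto the public API enum.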
impl From<i32> for FinishReason {
    fn from(finish_reason: i32) -> Self {
        let finish_reason = text_generation_client::FinishReason::try_from(finish_reason).unwrap();
        match finish_reason {
            text_generation_client::FinishReason::Length => FinishReason::Length,
            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
        }
    }
}

/// Convert to Axum supported formats
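/// Generation failures map to 424, overload to 429, validation/template/tool errors to 422, and incomplete generations to 500.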
impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
    fn from(err: InferError) -> Self {
        let status_code = match err {
            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,
            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,
            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,
            InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,
            InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,
        };

        (
            status_code,
            Json(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            }),
        )
    }
}

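// Errors surfaced mid-stream are emitted as SSE events carrying the same error payload.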
impl From<InferError> for Event {
    fn from(err: InferError) -> Self {
        Event::default()
            .json_data(ErrorResponse {
                error: err.to_string(),
                error_type: err.error_type().to_string(),
            })
            .unwrap()
    }
}