openai.rs 60 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
6
7
8
9
use std::{
    collections::HashSet,
    sync::Arc,
    time::{SystemTime, UNIX_EPOCH},
};

10
use axum::{
11
    Json, Router,
12
    extract::State,
Ryan Olson's avatar
Ryan Olson committed
13
    http::{HeaderMap, StatusCode},
14
15
    response::{
        IntoResponse, Response,
16
        sse::{Event, KeepAlive, Sse},
17
18
19
    },
    routing::{get, post},
};
Ryan Olson's avatar
Ryan Olson committed
20
21
22
23
use dynamo_runtime::{
    pipeline::{AsyncEngineContextProvider, Context},
    protocols::annotated::AnnotationsProvider,
};
24
use futures::{StreamExt, stream};
25
26
27
use serde::{Deserialize, Serialize};

use super::{
28
29
    RouteDoc,
    disconnect::{ConnectionHandle, create_connection_monitor, monitor_for_disconnects},
30
    error::HttpError,
Ryan Olson's avatar
Ryan Olson committed
31
    metrics::{Endpoint, ResponseMetricCollector},
32
    service_v2,
33
};
34
use crate::engines::ValidateRequest;
35
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
36
use crate::protocols::openai::{
37
    ParsingOptions,
38
39
40
41
    chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse},
    completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
    embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
    responses::{NvCreateResponse, NvResponse},
42
};
43
use crate::request_template::RequestTemplate;
44
use crate::types::Annotated;
GuanLuo's avatar
GuanLuo committed
45
use crate::{discovery::ModelManager, preprocessor::LLMMetricAnnotation};
46
47
use dynamo_runtime::logging::get_distributed_tracing_context;
use tracing::Instrument;
48

Ryan Olson's avatar
Ryan Olson committed
49
50
51
52
53
/// HTTP header through which a client may supply its own request ID.
/// The value is only honored when it parses as a UUID; otherwise a new ID is generated.
pub const DYNAMO_REQUEST_ID_HEADER: &str = "x-dynamo-request-id";

/// Dynamo annotation name for the request ID: when a request asks for this annotation, an
/// annotation event carrying the request's ID is emitted at the front of the response stream.
pub const ANNOTATION_REQUEST_ID: &str = "request_id";

54
55
56
57
58
59
60
61
62
63
64
/// Returns the maximum accepted HTTP request body size, in bytes.
///
/// Axum's default limit without configuration is 2MB
/// (<https://docs.rs/axum/latest/axum/extract/struct.DefaultBodyLimit.html>). That is too small
/// for very large prompts, so this defaults to 45MB to support 500k+ token payloads.
///
/// The limit can be overridden at runtime, in megabytes, via the `DYN_HTTP_BODY_LIMIT_MB`
/// environment variable. A value that does not parse as a `usize`, or whose byte count would
/// overflow `usize`, is ignored and the default is used instead.
fn get_body_limit() -> usize {
    const BYTES_PER_MB: usize = 1024 * 1024;
    const DEFAULT_BODY_LIMIT_BYTES: usize = 45 * BYTES_PER_MB;

    std::env::var("DYN_HTTP_BODY_LIMIT_MB")
        .ok()
        .and_then(|s| s.parse::<usize>().ok())
        // `checked_mul` avoids a debug-build panic / release-build wraparound on huge inputs.
        .and_then(|mb| mb.checked_mul(BYTES_PER_MB))
        .unwrap_or(DEFAULT_BODY_LIMIT_BYTES)
}

Ryan Olson's avatar
Ryan Olson committed
65
66
pub type ErrorResponse = (StatusCode, Json<ErrorMessage>);

67
#[derive(Serialize, Deserialize)]
Ryan Olson's avatar
Ryan Olson committed
68
pub(crate) struct ErrorMessage {
69
70
71
    error: String,
}

Ryan Olson's avatar
Ryan Olson committed
72
impl ErrorMessage {
73
    /// Not Found Error
Ryan Olson's avatar
Ryan Olson committed
74
    pub fn model_not_found() -> ErrorResponse {
75
76
        (
            StatusCode::NOT_FOUND,
Ryan Olson's avatar
Ryan Olson committed
77
            Json(ErrorMessage {
78
79
80
81
82
83
84
                error: "Model not found".to_string(),
            }),
        )
    }

    /// Service Unavailable
    /// This is returned when the service is live, but not ready.
Ryan Olson's avatar
Ryan Olson committed
85
    pub fn _service_unavailable() -> ErrorResponse {
86
87
        (
            StatusCode::SERVICE_UNAVAILABLE,
Ryan Olson's avatar
Ryan Olson committed
88
            Json(ErrorMessage {
89
90
91
92
93
94
95
96
97
                error: "Service is not ready".to_string(),
            }),
        )
    }

    /// Internal Service Error
    /// Return this error when the service encounters an internal error.
    /// We should return a generic message to the client instead of the real error.
    /// Internal Services errors are the result of misconfiguration or bugs in the service.
Ryan Olson's avatar
Ryan Olson committed
98
    pub fn internal_server_error(msg: &str) -> ErrorResponse {
99
100
101
        tracing::error!("Internal server error: {msg}");
        (
            StatusCode::INTERNAL_SERVER_ERROR,
Ryan Olson's avatar
Ryan Olson committed
102
            Json(ErrorMessage {
103
104
105
106
107
                error: msg.to_string(),
            }),
        )
    }

108
109
110
    /// Not Implemented Error
    /// Return this error when the client requests a feature that is not yet implemented.
    /// This should be used for features that are planned but not available.
Ryan Olson's avatar
Ryan Olson committed
111
    pub fn not_implemented_error(msg: &str) -> ErrorResponse {
112
113
114
        tracing::error!("Not Implemented error: {msg}");
        (
            StatusCode::NOT_IMPLEMENTED,
Ryan Olson's avatar
Ryan Olson committed
115
            Json(ErrorMessage {
116
117
118
119
120
                error: msg.to_string(),
            }),
        )
    }

Neelay Shah's avatar
Neelay Shah committed
121
    /// The OAI endpoints call an [`dynamo.runtime::engine::AsyncEngine`] which are specialized to return
122
    /// an [`anyhow::Error`]. This method will convert the [`anyhow::Error`] into an [`HttpError`].
Ryan Olson's avatar
Ryan Olson committed
123
    /// If successful, it will return the [`HttpError`] as an [`ErrorMessage::internal_server_error`]
124
    /// with the details of the error.
Ryan Olson's avatar
Ryan Olson committed
125
    pub fn from_anyhow(err: anyhow::Error, alt_msg: &str) -> ErrorResponse {
126
127
128
        // First check for PipelineError::ServiceOverloaded
        if let Some(pipeline_err) =
            err.downcast_ref::<dynamo_runtime::pipeline::error::PipelineError>()
129
            && matches!(
130
131
                pipeline_err,
                dynamo_runtime::pipeline::error::PipelineError::ServiceOverloaded(_)
132
133
134
135
136
137
138
139
            )
        {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(ErrorMessage {
                    error: pipeline_err.to_string(),
                }),
            );
140
141
142
        }

        // Then check for HttpError
143
        match err.downcast::<HttpError>() {
Ryan Olson's avatar
Ryan Olson committed
144
145
            Ok(http_error) => ErrorMessage::from_http_error(http_error),
            Err(err) => ErrorMessage::internal_server_error(&format!("{alt_msg}: {err}")),
146
147
148
149
        }
    }

    /// Implementers should only be able to throw 400-499 errors.
Ryan Olson's avatar
Ryan Olson committed
150
    pub fn from_http_error(err: HttpError) -> ErrorResponse {
151
        if err.code < 400 || err.code >= 500 {
Ryan Olson's avatar
Ryan Olson committed
152
            return ErrorMessage::internal_server_error(&err.message);
153
154
        }
        match StatusCode::from_u16(err.code) {
Ryan Olson's avatar
Ryan Olson committed
155
156
            Ok(code) => (code, Json(ErrorMessage { error: err.message })),
            Err(_) => ErrorMessage::internal_server_error(&err.message),
157
158
159
160
        }
    }
}

Ryan Olson's avatar
Ryan Olson committed
161
impl From<HttpError> for ErrorMessage {
162
    fn from(err: HttpError) -> Self {
Ryan Olson's avatar
Ryan Olson committed
163
        ErrorMessage { error: err.message }
164
165
166
    }
}

Ryan Olson's avatar
Ryan Olson committed
167
168
/// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present
fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
169
    // Try to get request id from trace context
170
171
172
173
    if let Some(trace_context) = get_distributed_tracing_context()
        && let Some(x_dynamo_request_id) = trace_context.x_dynamo_request_id
    {
        return x_dynamo_request_id;
174
175
    }

Ryan Olson's avatar
Ryan Olson committed
176
    // Try to get the request ID from the primary source
177
178
179
180
    if let Some(primary) = primary
        && let Ok(uuid) = uuid::Uuid::parse_str(primary)
    {
        return uuid.to_string();
Ryan Olson's avatar
Ryan Olson committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
    }

    // Try to get the request ID header as a string slice
    let request_id_opt = headers
        .get(DYNAMO_REQUEST_ID_HEADER)
        .and_then(|h| h.to_str().ok());

    // Try to parse the request ID as a UUID, or generate a new one if missing/invalid
    let uuid = match request_id_opt {
        Some(request_id) => {
            uuid::Uuid::parse_str(request_id).unwrap_or_else(|_| uuid::Uuid::new_v4())
        }
        None => uuid::Uuid::new_v4(),
    };

    uuid.to_string()
}

GuanLuo's avatar
GuanLuo committed
199
200
/// Builds the [`ParsingOptions`] used when folding `model`'s output stream.
///
/// Only the tool-call parser registered for `model` is looked up; no reasoning parser is
/// configured yet.
fn get_parsing_options(manager: &ModelManager, model: &str) -> ParsingOptions {
    // TODO: Implement reasoning parser
    ParsingOptions::new(manager.get_model_tool_call_parser(model), None)
}

206
207
208
209
210
211
212
213
/// OpenAI Completions Request Handler
///
/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
/// for an [`super::OpenAICompletionsStreamingEngine`] and will return a stream of
/// responses which will be forward to the client.
///
/// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For
/// non-streaming requests, we will fold the stream into a single response as part of this handler.
Ryan Olson's avatar
Ryan Olson committed
214
async fn handler_completions(
215
    State(state): State<Arc<service_v2::State>>,
Ryan Olson's avatar
Ryan Olson committed
216
    headers: HeaderMap,
217
    Json(request): Json<NvCreateCompletionRequest>,
Ryan Olson's avatar
Ryan Olson committed
218
219
220
221
222
223
224
225
226
227
) -> Result<Response, ErrorResponse> {
    // return a 503 if the service is not ready
    check_ready(&state)?;

    // create the context for the request
    let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
    let request = Context::with_id(request, request_id);
    let context = request.context();

    // create the connection handles
228
229
    let (mut connection_handle, stream_handle) =
        create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
Ryan Olson's avatar
Ryan Olson committed
230
231
232

    // possibly long running task
    // if this returns a streaming response, the stream handle will be armed and captured by the response stream
233
    let response = tokio::spawn(completions(state, request, stream_handle).in_current_span())
Ryan Olson's avatar
Ryan Olson committed
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
        .await
        .map_err(|e| {
            ErrorMessage::internal_server_error(&format!(
                "Failed to await chat completions task: {:?}",
                e,
            ))
        })?;

    // if we got here, then we will return a response and the potentially long running task has completed successfully
    // without need to be cancelled.
    connection_handle.disarm();

    response
}

#[tracing::instrument(skip_all)]
async fn completions(
    state: Arc<service_v2::State>,
    request: Context<NvCreateCompletionRequest>,
    stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
255
256
257
    // return a 503 if the service is not ready
    check_ready(&state)?;

258
259
    validate_completion_fields_generic(&request)?;

260
    let request_id = request.id().to_string();
261
262

    // todo - decide on default
263
    let streaming = request.inner.stream.unwrap_or(false);
264
265

    // update the request to always stream
Ryan Olson's avatar
Ryan Olson committed
266
267
268
269
    let request = request.map(|mut req| {
        req.inner.stream = Some(true);
        req
    });
270

271
272
    // todo - make the protocols be optional for model name
    // todo - when optional, if none, apply a default
273
    let model = &request.inner.model;
274
275
276

    // todo - error handling should be more robust
    let engine = state
277
        .manager()
278
        .get_completions_engine(model)
Ryan Olson's avatar
Ryan Olson committed
279
        .map_err(|_| ErrorMessage::model_not_found())?;
280

GuanLuo's avatar
GuanLuo committed
281
    let parsing_options = get_parsing_options(state.manager(), model);
282

283
    let mut inflight_guard =
284
285
286
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::Completions, streaming);
287

288
289
    let mut response_collector = state.metrics_clone().create_response_collector(model);

Ryan Olson's avatar
Ryan Olson committed
290
291
    // prepare to process any annotations
    let annotations = request.annotations();
292
293
294
295
296

    // issue the generate call on the engine
    let stream = engine
        .generate(request)
        .await
Ryan Olson's avatar
Ryan Olson committed
297
        .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
298
299
300
301

    // capture the context to cancel the stream if the client disconnects
    let ctx = stream.context();

Ryan Olson's avatar
Ryan Olson committed
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
    let annotations = annotations.map_or(Vec::new(), |annotations| {
        annotations
            .iter()
            .filter_map(|annotation| {
                if annotation == ANNOTATION_REQUEST_ID {
                    Annotated::<NvCreateCompletionResponse>::from_annotation(
                        ANNOTATION_REQUEST_ID,
                        &request_id,
                    )
                    .ok()
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
    });

    // apply any annotations to the front of the stream
    let stream = stream::iter(annotations).chain(stream);
321
322

    if streaming {
323
324
325
        let stream = stream.map(move |response| {
            process_event_converter(EventConverter::from(response), &mut response_collector)
        });
Ryan Olson's avatar
Ryan Olson committed
326
        let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
327

328
329
        let mut sse_stream = Sse::new(stream);

330
        if let Some(keep_alive) = state.sse_keep_alive() {
331
332
333
334
            sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
        }

        Ok(sse_stream.into_response())
335
    } else {
336
337
338
339
340
        // Tap the stream to collect metrics for non-streaming requests without altering items
        let stream = stream.inspect(move |response| {
            process_metrics_only(response, &mut response_collector);
        });

341
        let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
342
343
344
345
346
347
348
            .await
            .map_err(|e| {
                tracing::error!(
                    "Failed to fold completions stream for {}: {:?}",
                    request_id,
                    e
                );
Ryan Olson's avatar
Ryan Olson committed
349
                ErrorMessage::internal_server_error("Failed to fold completions stream")
350
351
            })?;

352
        inflight_guard.mark_ok();
353
354
355
356
        Ok(Json(response).into_response())
    }
}

357
358
#[tracing::instrument(skip_all)]
async fn embeddings(
359
    State(state): State<Arc<service_v2::State>>,
360
    headers: HeaderMap,
361
    Json(request): Json<NvCreateEmbeddingRequest>,
Ryan Olson's avatar
Ryan Olson committed
362
) -> Result<Response, ErrorResponse> {
363
364
365
    // return a 503 if the service is not ready
    check_ready(&state)?;

366
367
368
    let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
    let request = Context::with_id(request, request_id);
    let request_id = request.id().to_string();
369
370
371
372
373
374
375
376
377
378
379
380

    // Embeddings are typically not streamed, so we default to non-streaming
    let streaming = false;

    // todo - make the protocols be optional for model name
    // todo - when optional, if none, apply a default
    let model = &request.inner.model;

    // todo - error handling should be more robust
    let engine = state
        .manager()
        .get_embeddings_engine(model)
Ryan Olson's avatar
Ryan Olson committed
381
        .map_err(|_| ErrorMessage::model_not_found())?;
382
383
384
385
386
387
388
389
390
391
392

    // this will increment the inflight gauge for the model
    let mut inflight =
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::Embeddings, streaming);

    // issue the generate call on the engine
    let stream = engine
        .generate(request)
        .await
Ryan Olson's avatar
Ryan Olson committed
393
        .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate embeddings"))?;
394
395
396

    // Embeddings are typically returned as a single response (non-streaming)
    // so we fold the stream into a single response
Ryan Olson's avatar
Ryan Olson committed
397
    let response = NvCreateEmbeddingResponse::from_annotated_stream(stream)
398
399
400
401
402
403
404
        .await
        .map_err(|e| {
            tracing::error!(
                "Failed to fold embeddings stream for {}: {:?}",
                request_id,
                e
            );
Ryan Olson's avatar
Ryan Olson committed
405
            ErrorMessage::internal_server_error("Failed to fold embeddings stream")
406
407
408
409
        })?;

    inflight.mark_ok();
    Ok(Json(response).into_response())
410
411
}

Ryan Olson's avatar
Ryan Olson committed
412
413
414
415
416
417
418
419
420
421
422
423
424
425
async fn handler_chat_completions(
    State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
    headers: HeaderMap,
    Json(request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, ErrorResponse> {
    // return a 503 if the service is not ready
    check_ready(&state)?;

    // create the context for the request
    let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
    let request = Context::with_id(request, request_id);
    let context = request.context();

    // create the connection handles
426
427
    let (mut connection_handle, stream_handle) =
        create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
Ryan Olson's avatar
Ryan Olson committed
428

429
430
431
432
433
434
435
436
437
    let response =
        tokio::spawn(chat_completions(state, template, request, stream_handle).in_current_span())
            .await
            .map_err(|e| {
                ErrorMessage::internal_server_error(&format!(
                    "Failed to await chat completions task: {:?}",
                    e,
                ))
            })?;
Ryan Olson's avatar
Ryan Olson committed
438
439
440
441
442
443
444
445

    // if we got here, then we will return a response and the potentially long running task has completed successfully
    // without need to be cancelled.
    connection_handle.disarm();

    response
}

446
447
448
449
450
451
452
453
454
/// OpenAI Chat Completions Request Handler
///
/// This method will handle the incoming request for the /v1/chat/completions endpoint. The endpoint is a "source"
/// for an [`super::OpenAIChatCompletionsStreamingEngine`] and will return a stream of responses which will be
/// forward to the client.
///
/// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For
/// non-streaming requests, we will fold the stream into a single response as part of this handler.
async fn chat_completions(
Ryan Olson's avatar
Ryan Olson committed
455
456
457
458
459
    state: Arc<service_v2::State>,
    template: Option<RequestTemplate>,
    mut request: Context<NvCreateChatCompletionRequest>,
    mut stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
460
461
462
    // return a 503 if the service is not ready
    check_ready(&state)?;

Ryan Olson's avatar
Ryan Olson committed
463
464
    let request_id = request.id().to_string();

465
466
467
468
    // Handle unsupported fields - if Some(resp) is returned by
    // validate_chat_completion_unsupported_fields,
    // then a field was used that is unsupported. We will log an error message
    // and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed.
Ryan Olson's avatar
Ryan Olson committed
469
    validate_chat_completion_unsupported_fields(&request)?;
470

471
472
473
    // Handle required fields like messages shouldn't be empty.
    validate_chat_completion_required_fields(&request)?;

474
475
476
    // Handle Rest of Validation Errors
    validate_chat_completion_fields_generic(&request)?;

477
478
479
480
481
482
483
484
485
486
487
488
    // Apply template values if present
    if let Some(template) = template {
        if request.inner.model.is_empty() {
            request.inner.model = template.model.clone();
        }
        if request.inner.temperature.unwrap_or(0.0) == 0.0 {
            request.inner.temperature = Some(template.temperature);
        }
        if request.inner.max_completion_tokens.unwrap_or(0) == 0 {
            request.inner.max_completion_tokens = Some(template.max_completion_tokens);
        }
    }
Ryan Olson's avatar
Ryan Olson committed
489
    tracing::trace!("Received chat completions request: {:?}", request.content());
490
491

    // todo - decide on default
Paul Hendricks's avatar
Paul Hendricks committed
492
    let streaming = request.inner.stream.unwrap_or(false);
493
494

    // update the request to always stream
Ryan Olson's avatar
Ryan Olson committed
495
496
497
498
    let request = request.map(|mut req| {
        req.inner.stream = Some(true);
        req
    });
499
500
501

    // todo - make the protocols be optional for model name
    // todo - when optional, if none, apply a default
Paul Hendricks's avatar
Paul Hendricks committed
502
    let model = &request.inner.model;
503
504
505
506
507

    // todo - determine the proper error code for when a request model is not present
    tracing::trace!("Getting chat completions engine for model: {}", model);

    let engine = state
508
        .manager()
509
        .get_chat_completions_engine(model)
Ryan Olson's avatar
Ryan Olson committed
510
        .map_err(|_| ErrorMessage::model_not_found())?;
511

GuanLuo's avatar
GuanLuo committed
512
    let parsing_options = get_parsing_options(state.manager(), model);
513

514
    let mut inflight_guard =
515
516
517
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::ChatCompletions, streaming);
518

519
520
    let mut response_collector = state.metrics_clone().create_response_collector(model);

521
    tracing::trace!("Issuing generate call for chat completions");
Ryan Olson's avatar
Ryan Olson committed
522
    let annotations = request.annotations();
523
524
525
526
527

    // issue the generate call on the engine
    let stream = engine
        .generate(request)
        .await
Ryan Olson's avatar
Ryan Olson committed
528
        .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
529
530
531
532

    // capture the context to cancel the stream if the client disconnects
    let ctx = stream.context();

Ryan Olson's avatar
Ryan Olson committed
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
    // prepare any requested annotations
    let annotations = annotations.map_or(Vec::new(), |annotations| {
        annotations
            .iter()
            .filter_map(|annotation| {
                if annotation == ANNOTATION_REQUEST_ID {
                    Annotated::from_annotation(ANNOTATION_REQUEST_ID, &request_id).ok()
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
    });

    // apply any annotations to the front of the stream
    let stream = stream::iter(annotations).chain(stream);

550
551
552
553
    // todo - tap the stream and propagate request level metrics
    // note - we might do this as part of the post processing set to make it more generic

    if streaming {
Ryan Olson's avatar
Ryan Olson committed
554
555
        stream_handle.arm();

556
557
558
        let stream = stream.map(move |response| {
            process_event_converter(EventConverter::from(response), &mut response_collector)
        });
Ryan Olson's avatar
Ryan Olson committed
559
        let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
560

561
562
        let mut sse_stream = Sse::new(stream);

563
        if let Some(keep_alive) = state.sse_keep_alive() {
564
565
566
567
            sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
        }

        Ok(sse_stream.into_response())
568
    } else {
569
570
571
572
        let stream = stream.inspect(move |response| {
            process_metrics_only(response, &mut response_collector);
        });

573
574
575
576
577
578
579
580
581
582
583
584
585
586
        let response =
            NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
                .await
                .map_err(|e| {
                    tracing::error!(
                        request_id,
                        "Failed to fold chat completions stream for: {:?}",
                        e
                    );
                    ErrorMessage::internal_server_error(&format!(
                        "Failed to fold chat completions stream: {}",
                        e
                    ))
                })?;
587

588
        inflight_guard.mark_ok();
589
590
591
592
        Ok(Json(response).into_response())
    }
}

593
594
595
596
597
/// Checks for unsupported fields in the request.
/// Returns Some(response) if unsupported fields are present.
#[allow(deprecated)]
pub fn validate_chat_completion_unsupported_fields(
    request: &NvCreateChatCompletionRequest,
Ryan Olson's avatar
Ryan Olson committed
598
) -> Result<(), ErrorResponse> {
599
600
601
    let inner = &request.inner;

    if inner.parallel_tool_calls == Some(true) {
Ryan Olson's avatar
Ryan Olson committed
602
        return Err(ErrorMessage::not_implemented_error(
603
604
605
606
607
            "`parallel_tool_calls: true` is not supported.",
        ));
    }

    if inner.function_call.is_some() {
Ryan Olson's avatar
Ryan Olson committed
608
        return Err(ErrorMessage::not_implemented_error(
609
610
611
612
613
            "`function_call` is deprecated. Please migrate to use `tool_choice` instead.",
        ));
    }

    if inner.functions.is_some() {
Ryan Olson's avatar
Ryan Olson committed
614
        return Err(ErrorMessage::not_implemented_error(
615
616
617
618
            "`functions` is deprecated. Please migrate to use `tools` instead.",
        ));
    }

Ryan Olson's avatar
Ryan Olson committed
619
    Ok(())
620
621
}

622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
/// Ensures the fields a chat completion cannot proceed without are present.
///
/// # Errors
///
/// Returns a `400 Bad Request` [`ErrorResponse`] when `messages` is empty.
pub fn validate_chat_completion_required_fields(
    request: &NvCreateChatCompletionRequest,
) -> Result<(), ErrorResponse> {
    if !request.inner.messages.is_empty() {
        return Ok(());
    }

    let message = "The 'messages' field cannot be empty. At least one message is required.";
    Err(ErrorMessage::from_http_error(HttpError {
        code: 400,
        message: message.to_string(),
    }))
}

639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
/// Runs the `validate()` checks implemented for `NvCreateChatCompletionRequest`.
///
/// # Errors
///
/// Any validation failure is reported as a `400 Bad Request` [`ErrorResponse`] whose message is
/// the validator's error text, matching the OpenAI error shape.
pub fn validate_chat_completion_fields_generic(
    request: &NvCreateChatCompletionRequest,
) -> Result<(), ErrorResponse> {
    if let Err(reason) = request.validate() {
        return Err(ErrorMessage::from_http_error(HttpError {
            code: 400,
            message: reason.to_string(),
        }));
    }
    Ok(())
}

/// Runs the `validate()` checks implemented for `NvCreateCompletionRequest`.
///
/// # Errors
///
/// Any validation failure is reported as a `400 Bad Request` [`ErrorResponse`] whose message is
/// the validator's error text, matching the OpenAI error shape.
pub fn validate_completion_fields_generic(
    request: &NvCreateCompletionRequest,
) -> Result<(), ErrorResponse> {
    match request.validate() {
        Ok(()) => Ok(()),
        Err(reason) => Err(ErrorMessage::from_http_error(HttpError {
            code: 400,
            message: reason.to_string(),
        })),
    }
}

669
670
671
/// OpenAI Responses Request Handler
///
/// This method will handle the incoming request for the /v1/responses endpoint.
Ryan Olson's avatar
Ryan Olson committed
672
async fn handler_responses(
673
    State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
Ryan Olson's avatar
Ryan Olson committed
674
675
676
677
678
679
680
681
682
683
684
685
    headers: HeaderMap,
    Json(request): Json<NvCreateResponse>,
) -> Result<Response, ErrorResponse> {
    // return a 503 if the service is not ready
    check_ready(&state)?;

    // create the context for the request
    let request_id = get_or_create_request_id(request.inner.user.as_deref(), &headers);
    let request = Context::with_id(request, request_id);
    let context = request.context();

    // create the connection handles
686
687
    let (mut connection_handle, _stream_handle) =
        create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
Ryan Olson's avatar
Ryan Olson committed
688

689
    let response = tokio::spawn(responses(state, template, request).in_current_span())
Ryan Olson's avatar
Ryan Olson committed
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
        .await
        .map_err(|e| {
            ErrorMessage::internal_server_error(&format!(
                "Failed to await chat completions task: {:?}",
                e,
            ))
        })?;

    // if we got here, then we will return a response and the potentially long running task has completed successfully
    // without need to be cancelled.
    connection_handle.disarm();

    response
}

#[tracing::instrument(level = "debug", skip_all, fields(request_id = %request.id()))]
async fn responses(
    state: Arc<service_v2::State>,
    template: Option<RequestTemplate>,
    mut request: Context<NvCreateResponse>,
) -> Result<Response, ErrorResponse> {
711
712
713
714
715
716
    // return a 503 if the service is not ready
    check_ready(&state)?;

    // Handle unsupported fields - if Some(resp) is returned by validate_unsupported_fields,
    // then a field was used that is unsupported. We will log an error message
    // and early return a 501 NOT_IMPLEMENTED status code. Otherwise, proceeed.
717
    if let Some(resp) = validate_response_unsupported_fields(&request) {
718
719
720
721
722
723
724
        return Ok(resp.into_response());
    }

    // Handle non-text (image, audio, file) inputs - if Some(resp) is returned by
    // validate_input_is_text_only, then we are handling something other than Input::Text(_).
    // We will log an error message and early return a 501 NOT_IMPLEMENTED status code.
    // Otherwise, proceeed.
725
    if let Some(resp) = validate_response_input_is_text_only(&request) {
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
        return Ok(resp.into_response());
    }

    // Apply template values if present
    if let Some(template) = template {
        if request.inner.model.is_empty() {
            request.inner.model = template.model.clone();
        }
        if request.inner.temperature.unwrap_or(0.0) == 0.0 {
            request.inner.temperature = Some(template.temperature);
        }
        if request.inner.max_output_tokens.unwrap_or(0) == 0 {
            request.inner.max_output_tokens = Some(template.max_completion_tokens);
        }
    }
    tracing::trace!("Received chat completions request: {:?}", request.inner);

Ryan Olson's avatar
Ryan Olson committed
743
744
    let request_id = request.id().to_string();
    let (request, context) = request.into_parts();
745

Ryan Olson's avatar
Ryan Olson committed
746
    let mut request: NvCreateChatCompletionRequest = request.try_into().map_err(|e| {
747
748
749
750
751
        tracing::error!(
            request_id,
            "Failed to convert NvCreateResponse to NvCreateChatCompletionRequest: {:?}",
            e
        );
Ryan Olson's avatar
Ryan Olson committed
752
        ErrorMessage::not_implemented_error(&format!(
753
754
755
756
757
            "Only Input::Text(_) is currently supported: {}",
            e
        ))
    })?;

Ryan Olson's avatar
Ryan Olson committed
758
759
760
761
762
    let request = context.map(|mut _req| {
        request.inner.stream = Some(false);
        request
    });

763
764
765
766
767
768
769
    let model = &request.inner.model;

    tracing::trace!("Getting chat completions engine for model: {}", model);

    let engine = state
        .manager()
        .get_chat_completions_engine(model)
Ryan Olson's avatar
Ryan Olson committed
770
        .map_err(|_| ErrorMessage::model_not_found())?;
771

GuanLuo's avatar
GuanLuo committed
772
    let parsing_options = get_parsing_options(state.manager(), model);
773

774
775
776
777
778
779
780
781
782
783
784
785
786
    let mut inflight_guard =
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::Responses, false);

    let _response_collector = state.metrics_clone().create_response_collector(model);

    tracing::trace!("Issuing generate call for chat completions");

    // issue the generate call on the engine
    let stream = engine
        .generate(request)
        .await
Ryan Olson's avatar
Ryan Olson committed
787
        .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
788
789

    // TODO: handle streaming, currently just unary
790
791
792
793
794
795
796
797
798
799
800
801
802
803
    let response =
        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
            .await
            .map_err(|e| {
                tracing::error!(
                    request_id,
                    "Failed to fold chat completions stream for: {:?}",
                    e
                );
                ErrorMessage::internal_server_error(&format!(
                    "Failed to fold chat completions stream: {}",
                    e
                ))
            })?;
804
805
806
807
808
809
810
811

    // Convert NvCreateChatCompletionResponse --> NvResponse
    let response: NvResponse = response.try_into().map_err(|e| {
        tracing::error!(
            request_id,
            "Failed to convert NvCreateChatCompletionResponse to NvResponse: {:?}",
            e
        );
Ryan Olson's avatar
Ryan Olson committed
812
        ErrorMessage::internal_server_error("Failed to convert internal response")
813
814
815
816
817
818
819
    })?;

    inflight_guard.mark_ok();

    Ok(Json(response).into_response())
}

820
821
822
pub fn validate_response_input_is_text_only(
    request: &NvCreateResponse,
) -> Option<impl IntoResponse> {
823
    match &request.inner.input {
824
        dynamo_async_openai::types::responses::Input::Text(_) => None,
825
826
827
        _ => Some(ErrorMessage::not_implemented_error(
            "Only `Input::Text` is supported. Structured, multimedia, or custom input types are not yet implemented.",
        )),
828
829
830
831
832
    }
}

/// Checks for unsupported fields in the request.
/// Returns Some(response) if unsupported fields are present.
833
834
835
pub fn validate_response_unsupported_fields(
    request: &NvCreateResponse,
) -> Option<impl IntoResponse> {
836
837
838
    let inner = &request.inner;

    if inner.background == Some(true) {
Ryan Olson's avatar
Ryan Olson committed
839
        return Some(ErrorMessage::not_implemented_error(
840
841
842
843
            "`background: true` is not supported.",
        ));
    }
    if inner.include.is_some() {
Ryan Olson's avatar
Ryan Olson committed
844
        return Some(ErrorMessage::not_implemented_error(
845
846
847
848
            "`include` is not supported.",
        ));
    }
    if inner.instructions.is_some() {
Ryan Olson's avatar
Ryan Olson committed
849
        return Some(ErrorMessage::not_implemented_error(
850
851
852
853
            "`instructions` is not supported.",
        ));
    }
    if inner.max_tool_calls.is_some() {
Ryan Olson's avatar
Ryan Olson committed
854
        return Some(ErrorMessage::not_implemented_error(
855
856
857
858
            "`max_tool_calls` is not supported.",
        ));
    }
    if inner.metadata.is_some() {
Ryan Olson's avatar
Ryan Olson committed
859
        return Some(ErrorMessage::not_implemented_error(
860
861
862
863
            "`metadata` is not supported.",
        ));
    }
    if inner.parallel_tool_calls == Some(true) {
Ryan Olson's avatar
Ryan Olson committed
864
        return Some(ErrorMessage::not_implemented_error(
865
866
867
868
            "`parallel_tool_calls: true` is not supported.",
        ));
    }
    if inner.previous_response_id.is_some() {
Ryan Olson's avatar
Ryan Olson committed
869
        return Some(ErrorMessage::not_implemented_error(
870
871
872
873
            "`previous_response_id` is not supported.",
        ));
    }
    if inner.prompt.is_some() {
Ryan Olson's avatar
Ryan Olson committed
874
        return Some(ErrorMessage::not_implemented_error(
875
876
877
878
            "`prompt` is not supported.",
        ));
    }
    if inner.reasoning.is_some() {
Ryan Olson's avatar
Ryan Olson committed
879
        return Some(ErrorMessage::not_implemented_error(
880
881
882
883
            "`reasoning` is not supported.",
        ));
    }
    if inner.service_tier.is_some() {
Ryan Olson's avatar
Ryan Olson committed
884
        return Some(ErrorMessage::not_implemented_error(
885
886
887
888
            "`service_tier` is not supported.",
        ));
    }
    if inner.store == Some(true) {
Ryan Olson's avatar
Ryan Olson committed
889
        return Some(ErrorMessage::not_implemented_error(
890
891
892
893
            "`store: true` is not supported.",
        ));
    }
    if inner.stream == Some(true) {
Ryan Olson's avatar
Ryan Olson committed
894
        return Some(ErrorMessage::not_implemented_error(
895
896
897
898
            "`stream: true` is not supported.",
        ));
    }
    if inner.text.is_some() {
Ryan Olson's avatar
Ryan Olson committed
899
        return Some(ErrorMessage::not_implemented_error(
900
901
902
903
            "`text` is not supported.",
        ));
    }
    if inner.tool_choice.is_some() {
Ryan Olson's avatar
Ryan Olson committed
904
        return Some(ErrorMessage::not_implemented_error(
905
906
907
908
            "`tool_choice` is not supported.",
        ));
    }
    if inner.tools.is_some() {
Ryan Olson's avatar
Ryan Olson committed
909
        return Some(ErrorMessage::not_implemented_error(
910
911
912
913
            "`tools` is not supported.",
        ));
    }
    if inner.truncation.is_some() {
Ryan Olson's avatar
Ryan Olson committed
914
        return Some(ErrorMessage::not_implemented_error(
915
916
917
918
            "`truncation` is not supported.",
        ));
    }
    if inner.user.is_some() {
Ryan Olson's avatar
Ryan Olson committed
919
        return Some(ErrorMessage::not_implemented_error(
920
921
922
923
924
925
926
            "`user` is not supported.",
        ));
    }

    None
}

927
928
// todo - abstract this to the top level lib.rs to be reused
// todo - move the service_observer to its own state/arc
Ryan Olson's avatar
Ryan Olson committed
929
fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), ErrorResponse> {
930
    // if state.service_observer.stage() != ServiceStage::Ready {
Ryan Olson's avatar
Ryan Olson committed
931
    //     return Err(ErrorMessage::service_unavailable());
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
    // }
    Ok(())
}

/// openai compatible format
/// Example:
/// {
///  "object": "list",
///  "data": [
///    {
///      "id": "model-id-0",
///      "object": "model",
///      "created": 1686935002,
///      "owned_by": "organization-owner"
///    },
///    ]
/// }
async fn list_models_openai(
950
    State(state): State<Arc<service_v2::State>>,
Ryan Olson's avatar
Ryan Olson committed
951
) -> Result<Response, ErrorResponse> {
952
953
954
955
956
957
958
959
    check_ready(&state)?;

    let created = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let mut data = Vec::new();

960
961
    let models: HashSet<String> = state.manager().model_display_names();
    for model_name in models {
962
        data.push(ModelListing {
963
            id: model_name.clone(),
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
            object: "object",
            created,                        // Where would this come from? The GGUF?
            owned_by: "nvidia".to_string(), // Get organization from GGUF
        });
    }

    let out = ListModelOpenAI {
        object: "list",
        data,
    };
    Ok(Json(out).into_response())
}

/// JSON body returned by the OpenAI-compatible "list models" handler (`list_models_openai`).
#[derive(Serialize)]
struct ListModelOpenAI {
    object: &'static str, // always "list"
    data: Vec<ModelListing>, // one entry per model display name
}

/// A single model entry inside the `data` array of [`ListModelOpenAI`].
#[derive(Serialize)]
struct ModelListing {
    id: String,           // model display name
    object: &'static str, // object-type tag; the value is chosen in `list_models_openai`
    created: u64,         //  Seconds since epoch
    owned_by: String,     // owning organization reported to clients
}

/// Newtype around an [`Annotated`] item that is about to be rendered as an SSE [`Event`]
/// (see `process_event_converter`).
struct EventConverter<T>(Annotated<T>);

impl<T> From<Annotated<T>> for EventConverter<T> {
    /// Wraps the annotated item unchanged.
    fn from(inner: Annotated<T>) -> Self {
        Self(inner)
    }
}

999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
fn process_metrics_only<T>(
    annotated: &Annotated<T>,
    response_collector: &mut ResponseMetricCollector,
) {
    // update metrics
    if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(annotated) {
        response_collector.observe_current_osl(metrics.output_tokens);
        response_collector.observe_response(metrics.input_tokens, metrics.chunk_tokens);
    }
}

1010
1011
1012
1013
fn process_event_converter<T: Serialize>(
    annotated: EventConverter<T>,
    response_collector: &mut ResponseMetricCollector,
) -> Result<Event, axum::Error> {
1014
    let mut annotated = annotated.0;
1015

1016
1017
1018
1019
    // update metrics
    if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(&annotated) {
        response_collector.observe_current_osl(metrics.output_tokens);
        response_collector.observe_response(metrics.input_tokens, metrics.chunk_tokens);
1020
1021
1022
1023
1024
1025
1026

        // Chomp the LLMMetricAnnotation so it's not returned in the response stream
        // TODO: add a flag to control what is returned in the SSE stream
        if annotated.event.as_deref() == Some(crate::preprocessor::ANNOTATION_LLM_METRICS) {
            annotated.event = None;
            annotated.comment = None;
        }
1027
1028
    }

1029
    let mut event = Event::default();
1030

1031
1032
1033
    if let Some(data) = annotated.data {
        event = event.json_data(data)?;
    }
1034

1035
1036
1037
1038
1039
1040
    if let Some(msg) = annotated.event {
        if msg == "error" {
            let msgs = annotated
                .comment
                .unwrap_or_else(|| vec!["unspecified error".to_string()]);
            return Err(axum::Error::new(msgs.join(" -- ")));
1041
        }
1042
1043
        event = event.event(msg);
    }
1044

1045
1046
1047
1048
    if let Some(comments) = annotated.comment {
        for comment in comments {
            event = event.comment(comment);
        }
1049
    }
1050
1051

    Ok(event)
1052
1053
1054
1055
1056
}

/// Create an Axum [`Router`] for the OpenAI API Completions endpoint
/// If not path is provided, the default path is `/v1/completions`
pub fn completions_router(
1057
    state: Arc<service_v2::State>,
1058
1059
1060
1061
1062
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let path = path.unwrap_or("/v1/completions".to_string());
    let doc = RouteDoc::new(axum::http::Method::POST, &path);
    let router = Router::new()
Ryan Olson's avatar
Ryan Olson committed
1063
        .route(&path, post(handler_completions))
1064
        .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
1065
1066
1067
1068
1069
1070
1071
        .with_state(state);
    (vec![doc], router)
}

/// Create an Axum [`Router`] for the OpenAI API Chat Completions endpoint
/// If not path is provided, the default path is `/v1/chat/completions`
pub fn chat_completions_router(
1072
    state: Arc<service_v2::State>,
1073
    template: Option<RequestTemplate>,
1074
1075
1076
1077
1078
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let path = path.unwrap_or("/v1/chat/completions".to_string());
    let doc = RouteDoc::new(axum::http::Method::POST, &path);
    let router = Router::new()
Ryan Olson's avatar
Ryan Olson committed
1079
        .route(&path, post(handler_chat_completions))
1080
        .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
1081
        .with_state((state, template));
1082
1083
1084
    (vec![doc], router)
}

1085
1086
1087
/// Create an Axum [`Router`] for the OpenAI API Embeddings endpoint
/// If not path is provided, the default path is `/v1/embeddings`
pub fn embeddings_router(
1088
    state: Arc<service_v2::State>,
1089
1090
1091
1092
1093
1094
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let path = path.unwrap_or("/v1/embeddings".to_string());
    let doc = RouteDoc::new(axum::http::Method::POST, &path);
    let router = Router::new()
        .route(&path, post(embeddings))
1095
        .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
1096
1097
1098
1099
        .with_state(state);
    (vec![doc], router)
}

1100
1101
/// List Models
pub fn list_models_router(
1102
    state: Arc<service_v2::State>,
1103
1104
1105
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    // Standard OpenAI compatible list models endpoint
1106
    let openai_path = path.unwrap_or("/v1/models".to_string());
1107
1108
1109
1110
1111
1112
    let doc_for_openai = RouteDoc::new(axum::http::Method::GET, &openai_path);

    let router = Router::new()
        .route(&openai_path, get(list_models_openai))
        .with_state(state);

1113
    (vec![doc_for_openai], router)
1114
1115
}

1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
/// Create an Axum [`Router`] for the OpenAI API Responses endpoint
/// If not path is provided, the default path is `/v1/responses`
pub fn responses_router(
    state: Arc<service_v2::State>,
    template: Option<RequestTemplate>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let path = path.unwrap_or("/v1/responses".to_string());
    let doc = RouteDoc::new(axum::http::Method::POST, &path);
    let router = Router::new()
Ryan Olson's avatar
Ryan Olson committed
1126
        .route(&path, post(handler_responses))
1127
1128
1129
1130
        .with_state((state, template));
    (vec![doc], router)
}

1131
1132
#[cfg(test)]
mod tests {
1133
1134
    use std::collections::HashMap;

1135
1136
1137
1138
1139
1140
    use super::*;
    use crate::discovery::ModelManagerError;
    use crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
    use crate::protocols::openai::common_ext::CommonExt;
    use crate::protocols::openai::completions::NvCreateCompletionRequest;
    use crate::protocols::openai::responses::NvCreateResponse;
1141
    use dynamo_async_openai::types::responses::{
1142
1143
1144
1145
        CreateResponse, Input, InputContent, InputItem, InputMessage, PromptConfig,
        Role as ResponseRole, ServiceTier, TextConfig, TextResponseFormat, ToolChoice,
        ToolChoiceMode, Truncation,
    };
1146
    use dynamo_async_openai::types::{
1147
1148
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
1149
        CreateCompletionRequest,
1150
    };
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161

    const BACKUP_ERROR_MESSAGE: &str = "Failed to generate completions";

    fn http_error_from_engine(code: u16) -> Result<(), anyhow::Error> {
        Err(HttpError {
            code,
            message: "custom error message".to_string(),
        })?
    }

    fn other_error_from_engine() -> Result<(), anyhow::Error> {
1162
        Err(ModelManagerError::ModelNotFound("foo".to_string()))?
1163
1164
    }

1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
    fn make_base_request() -> NvCreateResponse {
        NvCreateResponse {
            inner: CreateResponse {
                input: Input::Text("hello".into()),
                model: "test-model".into(),
                background: None,
                include: None,
                instructions: None,
                max_output_tokens: None,
                max_tool_calls: None,
                metadata: None,
                parallel_tool_calls: None,
                previous_response_id: None,
                prompt: None,
                reasoning: None,
                service_tier: None,
                store: None,
                stream: None,
                text: None,
                tool_choice: None,
                tools: None,
                truncation: None,
                user: None,
                temperature: None,
                top_logprobs: None,
                top_p: None,
            },
            nvext: None,
        }
    }

1196
1197
1198
    #[test]
    fn test_http_error_response_from_anyhow() {
        let err = http_error_from_engine(400).unwrap_err();
Ryan Olson's avatar
Ryan Olson committed
1199
        let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
1200
1201
1202
1203
1204
1205
1206
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(response.error, "custom error message");
    }

    #[test]
    fn test_error_response_from_anyhow_out_of_range() {
        let err = http_error_from_engine(399).unwrap_err();
Ryan Olson's avatar
Ryan Olson committed
1207
        let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
1208
1209
1210
1211
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(response.error, "custom error message");

        let err = http_error_from_engine(500).unwrap_err();
Ryan Olson's avatar
Ryan Olson committed
1212
        let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
1213
1214
1215
1216
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(response.error, "custom error message");

        let err = http_error_from_engine(501).unwrap_err();
Ryan Olson's avatar
Ryan Olson committed
1217
        let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
1218
1219
1220
1221
1222
1223
1224
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(response.error, "custom error message");
    }

    #[test]
    fn test_other_error_response_from_anyhow() {
        let err = other_error_from_engine().unwrap_err();
Ryan Olson's avatar
Ryan Olson committed
1225
        let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            response.error,
            format!(
                "{}: {}",
                BACKUP_ERROR_MESSAGE,
                other_error_from_engine().unwrap_err()
            )
        );
    }
1236

1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
    #[test]
    fn test_service_overloaded_error_response_from_anyhow() {
        use dynamo_runtime::pipeline::error::PipelineError;

        let err: anyhow::Error = PipelineError::ServiceOverloaded(
            "All workers are busy, please retry later".to_string(),
        )
        .into();
        let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(
            response.error,
            "Service temporarily unavailable: All workers are busy, please retry later"
        );
    }

1253
1254
1255
    #[test]
    fn test_validate_input_is_text_only_accepts_text() {
        let request = make_base_request();
1256
        let result = validate_response_input_is_text_only(&request);
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
        assert!(result.is_none());
    }

    #[test]
    fn test_validate_input_is_text_only_rejects_items() {
        let mut request = make_base_request();
        request.inner.input = Input::Items(vec![InputItem::Message(InputMessage {
            kind: Default::default(),
            role: ResponseRole::User,
            content: InputContent::TextInput("structured".into()),
        })]);
1268
        let result = validate_response_input_is_text_only(&request);
1269
1270
1271
1272
1273
1274
        assert!(result.is_some());
    }

    #[test]
    fn test_validate_unsupported_fields_accepts_clean_request() {
        let request = make_base_request();
1275
        let result = validate_response_unsupported_fields(&request);
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
        assert!(result.is_none());
    }

    #[test]
    fn test_validate_unsupported_fields_detects_flags() {
        #[allow(clippy::type_complexity)]
        let unsupported_cases: Vec<(&str, Box<dyn FnOnce(&mut CreateResponse)>)> = vec![
            ("background", Box::new(|r| r.background = Some(true))),
            (
                "include",
                Box::new(|r| r.include = Some(vec!["file_search_call.results".into()])),
            ),
            (
                "instructions",
                Box::new(|r| r.instructions = Some("System prompt".into())),
            ),
            ("max_tool_calls", Box::new(|r| r.max_tool_calls = Some(3))),
            ("metadata", Box::new(|r| r.metadata = Some(HashMap::new()))),
            (
                "parallel_tool_calls",
                Box::new(|r| r.parallel_tool_calls = Some(true)),
            ),
            (
                "previous_response_id",
                Box::new(|r| r.previous_response_id = Some("prev-id".into())),
            ),
            (
                "prompt",
                Box::new(|r| {
                    r.prompt = Some(PromptConfig {
                        id: "template-id".into(),
                        version: None,
                        variables: None,
                    })
                }),
            ),
            (
                "reasoning",
                Box::new(|r| r.reasoning = Some(Default::default())),
            ),
            (
                "service_tier",
                Box::new(|r| r.service_tier = Some(ServiceTier::Auto)),
            ),
            ("store", Box::new(|r| r.store = Some(true))),
            ("stream", Box::new(|r| r.stream = Some(true))),
            (
                "text",
                Box::new(|r| {
                    r.text = Some(TextConfig {
                        format: TextResponseFormat::Text,
                    })
                }),
            ),
            (
                "tool_choice",
                Box::new(|r| r.tool_choice = Some(ToolChoice::Mode(ToolChoiceMode::Required))),
            ),
            ("tools", Box::new(|r| r.tools = Some(vec![]))),
            (
                "truncation",
                Box::new(|r| r.truncation = Some(Truncation::Auto)),
            ),
            ("user", Box::new(|r| r.user = Some("user-id".into()))),
        ];

        for (field, set_field) in unsupported_cases {
            let mut req = make_base_request();
            (set_field)(&mut req.inner);
1345
            let result = validate_response_unsupported_fields(&req);
1346
1347
1348
            assert!(result.is_some(), "Expected rejection for `{field}`");
        }
    }
1349
1350
1351
1352
1353
1354
1355
1356
1357

    #[test]
    fn test_validate_chat_completion_required_fields_empty_messages() {
        let request = NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages: vec![],
                ..Default::default()
            },
1358
            common: Default::default(),
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
            nvext: None,
        };
        let result = validate_chat_completion_required_fields(&request);
        assert!(result.is_err());
        if let Err((status, error_response)) = result {
            assert_eq!(status, StatusCode::BAD_REQUEST);
            assert_eq!(
                error_response.error,
                "The 'messages' field cannot be empty. At least one message is required."
            );
        }
    }

    #[test]
    fn test_validate_chat_completion_required_fields_with_messages() {
        let request = NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages: vec![ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
                        name: None,
                    },
                )],
                ..Default::default()
            },
1385
            common: Default::default(),
1386
1387
1388
1389
1390
            nvext: None,
        };
        let result = validate_chat_completion_required_fields(&request);
        assert!(result.is_ok());
    }
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702

    #[test]
    // Test for all Bad Requests Example for Chat Completion
    // 1. Echo:  Should be a boolean : Not Done
    // 2. Frequency Penalty: Should be a float between -2.0 and 2.0 : Done
    // 3. logprobs: Done
    // 4. Model Format: Should be a string : Not Done
    // 5. Prompt or Messages Validation
    // 6. Max Tokens: Should be a positive integer
    // 7. Presence Penalty: Should be a float between -2.0 and 2.0 : Done
    // 8. Stop : Should be a string or an array of strings : Not Done
    // 9. Invalid or Out of range temperature: Done
    // 10.Invalid or out of range top_p: Done
    // 11. Repetition Penalty: Should be a float between 0.0 and 2.0 : Done
    // 12. Logprobs: Should be a positive integer between 0 and 5 : Done
    // invalid or non existing user : Only empty string is not allowed validation is there. How can we check non-extisting user ?
    // add_special_tokens null or invalid : Not Done
    // guided_whitespace_pattern null or invalid : Not Done
    // "response_format": { "type": "invalid_format" } : Not Done
    // "logit_bias": { "invalid_token": "not_a_number" }, : Partial Validation is already there
    /// Checks that `validate_completion_fields_generic` rejects out-of-range sampling
    /// parameters on a completion request with `400 Bad Request` and the expected message.
    ///
    /// Each case starts from the same minimal valid request and overrides exactly one field,
    /// so a failure points at a single parameter's validation.
    fn test_bad_base_request_for_completion() {
        // Minimal valid completion request; each case overrides one field on top of it.
        let base_inner = || CreateCompletionRequest {
            model: "test-model".to_string(),
            prompt: "Hello".into(),
            ..Default::default()
        };
        // Wraps an inner request with default common extensions and no nvext.
        let with_inner = |inner| NvCreateCompletionRequest {
            inner,
            common: Default::default(),
            nvext: None,
        };

        // (case label, invalid request, expected validation message)
        let cases = [
            // Frequency Penalty: Should be a float between -2.0 and 2.0
            (
                "frequency_penalty",
                with_inner(CreateCompletionRequest {
                    frequency_penalty: Some(-3.0),
                    ..base_inner()
                }),
                "Frequency penalty must be between -2 and 2, got -3",
            ),
            // Presence Penalty: Should be a float between -2.0 and 2.0
            (
                "presence_penalty",
                with_inner(CreateCompletionRequest {
                    presence_penalty: Some(-3.0),
                    ..base_inner()
                }),
                "Presence penalty must be between -2 and 2, got -3",
            ),
            // Temperature: Should be a float between 0.0 and 2.0
            (
                "temperature",
                with_inner(CreateCompletionRequest {
                    temperature: Some(-3.0),
                    ..base_inner()
                }),
                "Temperature must be between 0 and 2, got -3",
            ),
            // Top P: Should be a float between 0.0 and 1.0
            (
                "top_p",
                with_inner(CreateCompletionRequest {
                    top_p: Some(-3.0),
                    ..base_inner()
                }),
                "Top_p must be between 0 and 1, got -3",
            ),
            // Repetition Penalty: Should be a float between 0.0 and 2.0 (lives in `common`)
            (
                "repetition_penalty",
                NvCreateCompletionRequest {
                    inner: base_inner(),
                    common: CommonExt::builder()
                        .repetition_penalty(-3.0)
                        .build()
                        .expect("CommonExt with only repetition_penalty set should build"),
                    nvext: None,
                },
                "Repetition penalty must be between 0 and 2, got -3",
            ),
            // Logprobs: Should be an integer between 0 and 5 (inclusive)
            (
                "logprobs",
                with_inner(CreateCompletionRequest {
                    logprobs: Some(6),
                    ..base_inner()
                }),
                "Logprobs must be between 0 and 5, got 6",
            ),
        ];

        for (case, request, expected) in cases {
            // `let ... else` both asserts the error and binds it, naming the failing case.
            let Err((status, error_response)) = validate_completion_fields_generic(&request)
            else {
                panic!("{case}: expected validation to fail, but it succeeded");
            };
            assert_eq!(status, StatusCode::BAD_REQUEST, "{case}: unexpected status");
            assert_eq!(
                error_response.error, expected,
                "{case}: unexpected error message"
            );
        }
    }

    #[test]
    fn test_bad_base_request_for_chatcompletion() {
        // Frequency Penalty: Should be a float between -2.0 and 2.0
        let request = NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages: vec![ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
                        name: None,
                    },
                )],
                frequency_penalty: Some(-3.0),
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
        };

        let result = validate_chat_completion_fields_generic(&request);
        assert!(result.is_err());
        if let Err((status, error_response)) = result {
            assert_eq!(status, StatusCode::BAD_REQUEST);
            assert_eq!(
                error_response.error,
                "Frequency penalty must be between -2 and 2, got -3"
            );
        }

        // Presence Penalty: Should be a float between -2.0 and 2.0
        let request = NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages: vec![ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
                        name: None,
                    },
                )],
                presence_penalty: Some(-3.0),
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
        };
        let result = validate_chat_completion_fields_generic(&request);
        assert!(result.is_err());
        if let Err((status, error_response)) = result {
            assert_eq!(status, StatusCode::BAD_REQUEST);
            assert_eq!(
                error_response.error,
                "Presence penalty must be between -2 and 2, got -3"
            );
        }

        // Temperature: Should be a float between 0.0 and 2.0
        let request = NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages: vec![ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
                        name: None,
                    },
                )],
                temperature: Some(-3.0),
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
        };
        let result = validate_chat_completion_fields_generic(&request);
        assert!(result.is_err());
        if let Err((status, error_response)) = result {
            assert_eq!(status, StatusCode::BAD_REQUEST);
            assert_eq!(
                error_response.error,
                "Temperature must be between 0 and 2, got -3"
            );
        }

        // Top P: Should be a float between 0.0 and 1.0
        let request = NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages: vec![ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
                        name: None,
                    },
                )],
                top_p: Some(-3.0),
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
        };
        let result = validate_chat_completion_fields_generic(&request);
        assert!(result.is_err());
        if let Err((status, error_response)) = result {
            assert_eq!(status, StatusCode::BAD_REQUEST);
            assert_eq!(
                error_response.error,
                "Top_p must be between 0 and 1, got -3"
            );
        }

        // Repetition Penalty: Should be a float between 0.0 and 2.0
        let request = NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages: vec![ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
                        name: None,
                    },
                )],
                ..Default::default()
            },
            common: CommonExt::builder()
                .repetition_penalty(-3.0)
                .build()
                .unwrap(),
            nvext: None,
        };
        let result = validate_chat_completion_fields_generic(&request);
        assert!(result.is_err());
        if let Err((status, error_response)) = result {
            assert_eq!(status, StatusCode::BAD_REQUEST);
            assert_eq!(
                error_response.error,
                "Repetition penalty must be between 0 and 2, got -3"
            );
        }

        // Top Logprobs: Should be a positive integer between 0 and 20
        let request = NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages: vec![ChatCompletionRequestMessage::User(
                    ChatCompletionRequestUserMessage {
                        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
                        name: None,
                    },
                )],
                top_logprobs: Some(25),
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
        };
        let result = validate_chat_completion_fields_generic(&request);
        assert!(result.is_err());
        if let Err((status, error_response)) = result {
            assert_eq!(status, StatusCode::BAD_REQUEST);
            assert_eq!(
                error_response.error,
                "Top_logprobs must be between 0 and 20, got 25"
            );
        }
    }
1703
}