preprocessor.rs 46.3 KB
Newer Older
Biswa Panda's avatar
Biswa Panda committed
1
2
3
4
5
6
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! The Preprocessor consists of the following modules:
//!
//! - `translation`: This module converts the allowed Ingress message types to the corresponding
//!   internal representation.
//! - `apply`: This module applies ModelConfig defaults to any optional fields left empty in the
//!   request.
//! - `prompt`: This module applies any prompt template logic to the internal Request object.
//! - `tokenize`: This module tokenizes the formatted prompt string and returns the token ids.
//!
//! The Preprocessor will accept any IngressRequest and transform it to a BackendRequest.

pub mod prompt;
pub mod tools;

use anyhow::Result;
18
use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
19
use dynamo_async_openai::types::EncodingFormat;
Biswa Panda's avatar
Biswa Panda committed
20
21
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
22
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
Biswa Panda's avatar
Biswa Panda committed
23
24
25
use std::{collections::HashMap, sync::Arc};
use tracing;

26
27
28
29
use dynamo_parsers::tool_calling::{
    parsers::detect_tool_call_start, try_tool_call_parse_aggregate,
};

30
use crate::model_card::{ModelDeploymentCard, ModelInfo};
Biswa Panda's avatar
Biswa Panda committed
31
use crate::preprocessor::prompt::OAIChatLikeRequest;
32
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
33
use crate::tokenizers::Encoding;
Biswa Panda's avatar
Biswa Panda committed
34

Neelay Shah's avatar
Neelay Shah committed
35
36
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
37
    AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
Biswa Panda's avatar
Biswa Panda committed
38
};
Neelay Shah's avatar
Neelay Shah committed
39
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};
Biswa Panda's avatar
Biswa Panda committed
40
41

use crate::protocols::{
Greg Clark's avatar
Greg Clark committed
42
    common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
Biswa Panda's avatar
Biswa Panda committed
43
    openai::{
44
        DeltaGeneratorExt,
45
        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
46
        completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
47
        embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
Biswa Panda's avatar
Biswa Panda committed
48
49
50
        nvext::NvExtProvider,
    },
};
51
use crate::tokenizers::{HuggingFaceTokenizer, traits::Tokenizer};
Biswa Panda's avatar
Biswa Panda committed
52

53
use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};
Biswa Panda's avatar
Biswa Panda committed
54

55
pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
56
57
58
pub use crate::protocols::common::preprocessor::PreprocessedEmbeddingRequest;

use crate::protocols::common::llm_backend::EmbeddingsEngineOutput;
Biswa Panda's avatar
Biswa Panda committed
59
60
61

// ---- Annotation keys / event names shared by the preprocessor and its consumers ----

/// Annotation key under which the rendered (templated) prompt is returned, when requested.
pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";

/// Annotation key under which the request's token ids are returned, when requested.
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";

/// Event name used for per-chunk LLM metrics annotations and the trailing usage chunk.
pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics";

/// Event name used for possible-tool-call annotations.
pub const ANNOTATION_POSSIBLE_TOOL_CALL: &str = "possible_tool_call";
64
65
66
67
68
69
70
/// Token accounting attached to postprocessed response chunks as an `llm_metrics` event.
///
/// Serialized to JSON and carried as the single comment of the annotation
/// (see `to_annotation` / `from_annotation`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMMetricAnnotation {
    /// Input sequence length reported by the response generator.
    /// `transform_postprocessor_stream` reports 0 for chunks that carry no data.
    pub input_tokens: usize,
    /// Cumulative number of output tokens observed so far in the stream.
    pub output_tokens: usize,
    /// Number of output tokens contained in this particular chunk.
    pub chunk_tokens: usize,
}

71
72
73
74
75
76
77
78
79
80
/// State threaded through the tool-calling "jail" stream built by
/// `apply_tool_calling_jail_internal`.
#[derive(Debug)]
pub struct JailState {
    /// Upstream chat-completion stream being wrapped.
    stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    /// Whether the stream is currently jailed (output held back until the input stream ends).
    is_jailed: bool,
    /// Name of the tool-call parser to use, if one is configured.
    tool_call_parser: Option<String>,
    accumulated_content: HashMap<u32, String>, // choice index -> accumulated content
    last_response_metadata: Option<NvCreateChatCompletionStreamResponse>, // for response structure (NOTE(review): presumably reused as a template when releasing accumulated content — confirm)
    finished: bool,                            // set once the jailed stream is done; checked at the top of each unfold step
}

81
82
83
84
85
86
87
88
89
90
91
92
93
94
/// Decide whether tool-call parsing should be enabled for a chat completion request.
///
/// Tool-call parsing is enabled when both of the following hold:
/// 1. a tool-call parser is configured (`parser_str` is `Some`), and
/// 2. the request did not explicitly choose `ChatCompletionToolChoiceOption::None`
///    (the OpenAI `"none"` tool choice).
///
/// An *unset* `tool_choice` (`Option::None`) does **not** disable tool calling; only an
/// explicit `"none"` choice does.
pub fn maybe_enable_tool_call(
    parser_str: Option<&str>,
    request: &NvCreateChatCompletionRequest,
) -> bool {
    let tool_choice_is_none = matches!(
        request.inner.tool_choice,
        Some(ChatCompletionToolChoiceOption::None)
    );
    parser_str.is_some() && !tool_choice_is_none
}

95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
impl LLMMetricAnnotation {
    /// Convert these metrics into an [`Annotated`] event named [`ANNOTATION_LLM_METRICS`].
    ///
    /// # Errors
    /// Returns the `serde_json` error if the metrics cannot be serialized.
    pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
        Annotated::from_annotation(ANNOTATION_LLM_METRICS, self)
    }

    /// Extract LLM metrics from an [`Annotated`] event, if present.
    ///
    /// Returns `Ok(None)` when the annotation has no event, or when its event is not
    /// [`ANNOTATION_LLM_METRICS`].
    ///
    /// # Errors
    /// When the event matches, this fails if the comment block is missing, if it does not
    /// contain exactly one comment, or if that comment is not valid JSON for this type.
    pub fn from_annotation<T>(
        annotation: &Annotated<T>,
    ) -> Result<Option<LLMMetricAnnotation>, Box<dyn std::error::Error>> {
        if annotation.event.as_deref() != Some(ANNOTATION_LLM_METRICS) {
            return Ok(None);
        }
        let comments = annotation
            .comment
            .as_ref()
            .ok_or("missing comments block")?;
        let [comment] = comments.as_slice() else {
            return Err("malformed comments block - expected exactly 1 comment".into());
        };
        let metrics: LLMMetricAnnotation = serde_json::from_str(comment)?;
        Ok(Some(metrics))
    }
}
Biswa Panda's avatar
Biswa Panda committed
122

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
/// Payload of a `possible_tool_call` annotation.
///
/// Carried as JSON in the single comment of the annotation (see `to_annotation` /
/// `from_annotation`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PossibleToolCallAnnotation {
    /// NOTE(review): presumably the number of tokens buffered as a possible tool call — confirm
    /// against the producer.
    pub possible_tokens: usize,
    /// NOTE(review): presumably the buffered text that may contain a tool call — confirm.
    pub possible_content: String,
    /// Name of the tool-call parser involved, if any (assumed from the field name; verify).
    pub parser_used: Option<String>,
}

impl PossibleToolCallAnnotation {
    /// Convert this possible-tool-call info into an [`Annotated`] event named
    /// [`ANNOTATION_POSSIBLE_TOOL_CALL`].
    ///
    /// # Errors
    /// Returns the `serde_json` error if the payload cannot be serialized.
    pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
        Annotated::from_annotation(ANNOTATION_POSSIBLE_TOOL_CALL, self)
    }

    /// Extract possible-tool-call info from an [`Annotated`] event, if present.
    ///
    /// Returns `Ok(None)` when the annotation has no event, or when its event is not
    /// [`ANNOTATION_POSSIBLE_TOOL_CALL`].
    ///
    /// # Errors
    /// When the event matches, this fails if the comment block is missing, if it does not
    /// contain exactly one comment, or if that comment is not valid JSON for this type.
    pub fn from_annotation<T>(
        annotation: &Annotated<T>,
    ) -> Result<Option<PossibleToolCallAnnotation>, Box<dyn std::error::Error>> {
        if annotation.event.as_deref() != Some(ANNOTATION_POSSIBLE_TOOL_CALL) {
            return Ok(None);
        }
        let comments = annotation
            .comment
            .as_ref()
            .ok_or("missing comments block")?;
        let [comment] = comments.as_slice() else {
            return Err("malformed comments block - expected exactly 1 comment".into());
        };
        let possible_info: PossibleToolCallAnnotation = serde_json::from_str(comment)?;
        Ok(Some(possible_info))
    }
}

Biswa Panda's avatar
Biswa Panda committed
158
159
160
161
162
pub struct OpenAIPreprocessor {
    mdcsum: String,
    formatter: Arc<dyn OAIPromptFormatter>,
    tokenizer: Arc<dyn Tokenizer>,
    model_info: Arc<dyn ModelInfo>,
163
164
    /// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser)
    runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
165
    tool_call_parser: Option<String>,
Biswa Panda's avatar
Biswa Panda committed
166
167
168
}

impl OpenAIPreprocessor {
169
170
171
    pub fn new(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
        let formatter = PromptFormatter::from_mdc(&mdc)?;
        let tokenizer = mdc.tokenizer_hf()?;
172
        match formatter {
173
            PromptFormatter::OAI(formatter) => Self::new_with_parts(mdc, formatter, tokenizer),
174
175
176
        }
    }

177
    pub fn new_with_parts(
178
179
        mdc: ModelDeploymentCard,
        formatter: Arc<dyn OAIPromptFormatter>,
180
        hf_tokenizer: tokenizers::Tokenizer,
181
182
    ) -> Result<Arc<Self>> {
        let mdcsum = mdc.mdcsum();
183
        let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
184
185
186
187
188
        let Some(model_info) = mdc.model_info else {
            anyhow::bail!(
                "Blank ModelDeploymentCard cannot be used for pre-processing, no model_info"
            );
        };
189
        let model_info = model_info.get_model_info()?;
190
        let tool_call_parser = mdc.runtime_config.tool_call_parser.clone();
Biswa Panda's avatar
Biswa Panda committed
191

192
193
194
        // // Initialize runtime config from the ModelDeploymentCard
        let runtime_config = mdc.runtime_config.clone();

Biswa Panda's avatar
Biswa Panda committed
195
196
197
198
199
        Ok(Arc::new(Self {
            formatter,
            tokenizer,
            model_info,
            mdcsum,
200
            runtime_config,
201
            tool_call_parser,
Biswa Panda's avatar
Biswa Panda committed
202
203
        }))
    }
204
205
206
207
208
    /// Encode a string to it's tokens
    pub fn tokenize(&self, s: &str) -> anyhow::Result<Encoding> {
        self.tokenizer.encode(s)
    }

209
    /// Translate a [`NvCreateChatCompletionRequest`] request to a common completion request.
Biswa Panda's avatar
Biswa Panda committed
210
211
212
213
214
215
216
217
218
219
    /// Returns both the common completion request and a hashmap of annotations.
    ///
    /// Annotations evaluated by this method include:
    /// - `formatted_prompt`
    /// - `token_ids`
    pub fn preprocess_request<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
Greg Clark's avatar
Greg Clark committed
220
            + OutputOptionsProvider
Biswa Panda's avatar
Biswa Panda committed
221
222
223
224
            + NvExtProvider,
    >(
        &self,
        request: &R,
225
    ) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
        let mut builder = self.builder(request)?;
        let formatted_prompt = self.apply_template(request)?;
        let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;

        Ok((builder.build()?, annotations))
    }

    pub fn builder<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<PreprocessedRequestBuilder> {
244
        let mut builder = PreprocessedRequest::builder();
245
        builder.model(request.model());
Biswa Panda's avatar
Biswa Panda committed
246

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
        let mut stop_conditions = request.extract_stop_conditions()?;
        if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
            for eos_token in self.model_info.eos_token_ids() {
                if !stop_tokens.contains(&eos_token) {
                    stop_tokens.push(eos_token);
                }
            }
        } else {
            stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
        }

        // apply ignore eos if not already set
        stop_conditions.apply_ignore_eos();

        if !stop_conditions.ignore_eos.unwrap_or(false) {
            builder.eos_token_ids(self.model_info.eos_token_ids());
        }

        builder.stop_conditions(stop_conditions);
        builder.sampling_options(request.extract_sampling_options()?);
        builder.output_options(request.extract_output_options()?);
        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));
        builder.estimated_prefix_hit_num_blocks(None);
        // Extract backend_instance_id from nvext if present
        if let Some(nvext) = request.nvext() {
            builder.backend_instance_id(nvext.backend_instance_id);
        }

        Ok(builder)
    }

    pub fn apply_template<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
    ) -> Result<Option<String>> {
        if let PromptInput::Text(_) = request.prompt_input_type()
            && let Some(TextInput::Single(_)) = request.extract_text()
        {
            let use_raw_prompt = request
                .nvext()
                .is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));

            let formatted_prompt = if use_raw_prompt {
                match request.raw_prompt() {
                    Some(prompt) => prompt,
                    None => {
                        tracing::warn!("Raw prompt requested but not available");
                        self.formatter.render(request)?
                    }
                }
            } else {
                self.formatter.render(request)?
            };
            Ok(Some(formatted_prompt))
        } else {
            Ok(None)
        }
    }

    pub fn gather_tokens<
        R: OAIChatLikeRequest
            + AnnotationsProvider
            + SamplingOptionsProvider
            + StopConditionsProvider
            + OutputOptionsProvider
            + NvExtProvider,
    >(
        &self,
        request: &R,
        builder: &mut PreprocessedRequestBuilder,
        formatted_prompt: Option<String>,
    ) -> Result<HashMap<String, String>> {
        let mut annotations = HashMap::new();
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
        // match request type before any conversion/processing
        match request.prompt_input_type() {
            PromptInput::Tokens(_) => {
                if let Some(token_input) = request.extract_tokens() {
                    match token_input {
                        TokenInput::Single(tokens) => {
                            builder.token_ids(tokens);
                        }
                        TokenInput::Batch(token_batches) => {
                            if token_batches.len() == 1 {
                                builder.token_ids(token_batches[0].clone());
                            } else {
                                builder.batch_token_ids(Some(token_batches));
                                builder.token_ids(vec![]);
                            }
                        }
                    }
Biswa Panda's avatar
Biswa Panda committed
345
346
                }
            }
347
348
349
            PromptInput::Text(_) => {
                if let Some(text_input) = request.extract_text() {
                    match text_input {
350
351
352
353
354
355
356
357
358
359
                        TextInput::Single(raw_prompt) => {
                            if let Some(f) = formatted_prompt.as_ref()
                                && request.has_annotation(ANNOTATION_FORMATTED_PROMPT)
                            {
                                annotations
                                    .insert(ANNOTATION_FORMATTED_PROMPT.to_string(), f.to_string());
                            }

                            // Completions will use raw_prompt, no template
                            let prompt = formatted_prompt.unwrap_or(raw_prompt);
Biswa Panda's avatar
Biswa Panda committed
360

361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
                            // Check if backend_instance_id is present and token_data is provided
                            let has_backend_instance_id = request
                                .nvext()
                                .and_then(|ext| ext.backend_instance_id)
                                .is_some();

                            let token_data =
                                request.nvext().and_then(|ext| ext.token_data.as_ref());

                            let (tokens_vec, skip_token_annotation) = if has_backend_instance_id {
                                if let Some(tokens) = token_data {
                                    tracing::trace!(
                                        "Using provided tokens from EPP: {} ids",
                                        tokens.len()
                                    );
                                    // need ownership for the builder, so clone.
                                    (tokens.clone(), true)
                                } else {
                                    tracing::warn!(
                                        "backend_instance_id provided but no token_data; tokenizing prompt"
                                    );
382
                                    let encoding = self.tokenizer.encode(&prompt)?;
383
384
385
386
                                    (encoding.token_ids().to_vec(), false)
                                }
                            } else {
                                // No backend_instance_id provided, continue the normal flow.
387
                                let encoding = self.tokenizer.encode(&prompt)?;
388
389
                                (encoding.token_ids().to_vec(), false)
                            };
Biswa Panda's avatar
Biswa Panda committed
390

391
392
393
                            if request.has_annotation(ANNOTATION_TOKEN_IDS)
                                && !skip_token_annotation
                            {
394
395
                                annotations.insert(
                                    ANNOTATION_TOKEN_IDS.to_string(),
396
                                    serde_json::to_string(&tokens_vec)?,
397
398
399
                                );
                            }

400
                            builder.token_ids(tokens_vec);
401
402
                        }
                        TextInput::Batch(texts) => {
403
                            let token_batches: Vec<Vec<u32>> = texts
404
405
                                .par_iter()
                                .map(|text| {
406
407
408
                                    self.tokenizer
                                        .encode(text)
                                        .map(|encoded| encoded.token_ids().to_vec())
409
                                })
410
                                .collect::<Result<Vec<_>>>()?;
411
412
413
414
415
416
                            builder.batch_token_ids(Some(token_batches));
                            builder.token_ids(vec![]);
                        }
                    }
                }
            }
Biswa Panda's avatar
Biswa Panda committed
417
        }
418
        Ok(annotations)
Biswa Panda's avatar
Biswa Panda committed
419
420
    }

421
422
423
424
425
426
427
428
429
430
431
432
433
434
    /// Preprocess an embedding request, handling both text and token ID inputs.
    ///
    /// For text inputs, tokenizes the text using the configured tokenizer.
    /// For token ID inputs, uses the provided token IDs directly and skips tokenization.
    ///
    /// Returns both the preprocessed request and a hashmap of annotations.
    pub async fn preprocess_embedding_request(
        &self,
        request: &NvCreateEmbeddingRequest,
    ) -> Result<(PreprocessedEmbeddingRequest, HashMap<String, String>)> {
        let mut annotations = HashMap::new();
        let mut builder = PreprocessedEmbeddingRequest::builder();

        let all_token_ids = match &request.inner.input {
435
            dynamo_async_openai::types::EmbeddingInput::String(s) => {
436
437
                let encoding = self.tokenizer.encode(s)?;
                vec![encoding.token_ids().to_vec()]
438
            }
439
            dynamo_async_openai::types::EmbeddingInput::StringArray(arr) => {
440
441
442
443
444
445
446
447
448
449
450
                let input_strs: Vec<String> = arr.to_vec();
                let encodings = tokio::task::spawn_blocking({
                    let tokenizer = self.tokenizer.clone();
                    let strs = input_strs.clone();
                    move || {
                        tokenizer.encode_batch(&strs.iter().map(|s| s.as_str()).collect::<Vec<_>>())
                    }
                })
                .await??;
                let token_arrays: Vec<Vec<u32>> = encodings
                    .into_iter()
451
                    .map(|encoding| encoding.token_ids().to_vec())
452
453
454
                    .collect();
                token_arrays
            }
455
456
457
458
            dynamo_async_openai::types::EmbeddingInput::IntegerArray(token_ids) => {
                vec![token_ids.clone()]
            }
            dynamo_async_openai::types::EmbeddingInput::ArrayOfIntegerArray(token_arrays) => {
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
                token_arrays.clone()
            }
        };

        // Handle annotations
        if request.has_annotation(ANNOTATION_TOKEN_IDS) {
            annotations.insert(
                ANNOTATION_TOKEN_IDS.to_string(),
                serde_json::to_string(&all_token_ids)?,
            );
        }

        builder.token_ids(all_token_ids);
        builder.model(request.inner.model.clone());
        builder.encoding_format(request.inner.encoding_format.as_ref().map(|f| match f {
            EncodingFormat::Float => "float".to_string(),
            EncodingFormat::Base64 => "base64".to_string(),
        }));
        builder.dimensions(request.inner.dimensions);

        builder.annotations(request.annotations().unwrap_or_default());
        builder.mdc_sum(Some(self.mdcsum.clone()));

        Ok((builder.build()?, annotations))
    }

Biswa Panda's avatar
Biswa Panda committed
485
486
487
488
489
490
491
492
493
494
495
    pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>(
        stream: ManyOut<Annotated<BackendOutput>>,
        generator: Box<dyn DeltaGeneratorExt<Resp>>,
    ) -> ManyOut<Annotated<Resp>> {
        let context = stream.context();

        struct State<Resp: Send + Sync + 'static + std::fmt::Debug> {
            response_stream: ManyOut<Annotated<BackendOutput>>,
            response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
            context: Arc<dyn AsyncEngineContext>,
            cancelled: bool,
496
            cumulative_output_tokens: usize,
497
498
            finish_reason_sent: bool,
            usage_chunk_sent: bool,
499
            finished: bool, // Add this flag to track if stream is finished
Biswa Panda's avatar
Biswa Panda committed
500
501
502
503
504
505
506
        }

        let state = State {
            response_stream: stream,
            response_generator: generator,
            context: context.clone(),
            cancelled: false,
507
            cumulative_output_tokens: 0,
508
509
            finish_reason_sent: false,
            usage_chunk_sent: false,
510
            finished: false, // Initialize as not finished
Biswa Panda's avatar
Biswa Panda committed
511
512
513
514
515
        };

        // transform the common response stream into a chat response stream
        let stream = stream::unfold(state, |mut inner| {
            async move {
516
517
518
519
520
                // If already finished, return None immediately
                if inner.finished {
                    return None;
                }

Biswa Panda's avatar
Biswa Panda committed
521
522
523
524
525
526
                if let Some(response) = inner.response_stream.next().await {
                    if inner.cancelled {
                        tracing::debug!(
                            request_id = inner.context.id(),
                            "Cancellation issued last message; closing stream"
                        );
527
                        inner.finished = true; // Mark as finished
Biswa Panda's avatar
Biswa Panda committed
528
529
530
531
532
533
534
535
536
                        return None;
                    }

                    tracing::trace!(
                        request_id = inner.context.id(),
                        "Processing common response: {:?}",
                        response
                    );

537
538
539
540
541
542
543
                    // Check if this response has a finish_reason
                    let has_finish_reason = response
                        .data
                        .as_ref()
                        .map(|d| d.finish_reason.is_some())
                        .unwrap_or(false);

544
545
546
547
548
549
550
551
552
553
554
555
556
557
                    let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data {
                        let chunk_tokens = backend_output.token_ids.len();
                        inner.cumulative_output_tokens += chunk_tokens;

                        let isl = inner.response_generator.get_isl().unwrap_or(0) as usize;

                        (chunk_tokens, isl)
                    } else {
                        (0, 0)
                    };

                    let current_osl = inner.cumulative_output_tokens;

                    let mut response = response.map_data(|data| {
Biswa Panda's avatar
Biswa Panda committed
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
                        inner
                            .response_generator
                            .choice_from_postprocessor(data)
                            .inspect_err(|e| {
                                tracing::error!(
                                    request_id = inner.context.id(),
                                    "Error processing common response: {:?}",
                                    e
                                );
                                inner.cancelled = true;
                                inner.context.stop_generating();
                            })
                            .map_err(|e| e.to_string())
                    });

573
574
575
576
577
578
579
580
581
582
583
                    // Create LLM metrics annotation
                    let llm_metrics = LLMMetricAnnotation {
                        input_tokens: isl,
                        output_tokens: current_osl,
                        chunk_tokens,
                    };

                    if let Ok(metrics_annotated) = llm_metrics.to_annotation::<()>() {
                        // Only set event if not already set to avoid overriding existing events (like errors)
                        if response.event.is_none() {
                            response.event = metrics_annotated.event;
584
                            response.comment = metrics_annotated.comment;
585
586
                        }
                    }
587

588
589
590
591
592
                    // Mark if we've seen a finish_reason
                    if has_finish_reason {
                        inner.finish_reason_sent = true;
                    }

Biswa Panda's avatar
Biswa Panda committed
593
594
                    tracing::trace!(
                        request_id = inner.context.id(),
595
                        "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
Biswa Panda's avatar
Biswa Panda committed
596
597
598
599
600
                        response
                    );

                    Some((response, inner))
                } else {
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
                    // Stream has ended - check if we need to send a usage chunk
                    if inner.response_generator.is_usage_enabled()
                        && inner.finish_reason_sent
                        && !inner.usage_chunk_sent
                        && !inner.finished
                    {
                        inner.usage_chunk_sent = true;

                        // Create the final usage chunk
                        let usage_chunk = inner.response_generator.create_usage_chunk();
                        let annotated_usage = Annotated::<Resp> {
                            id: None,
                            data: Some(usage_chunk),
                            event: Some(ANNOTATION_LLM_METRICS.to_string()),
                            comment: None,
                        };

                        tracing::trace!(
                            request_id = inner.context.id(),
                            "Sending final usage chunk for OpenAI compliance"
                        );

                        Some((annotated_usage, inner))
                    } else {
                        // stream closed
                        inner.finished = true; // Mark as finished
                        None
                    }
Biswa Panda's avatar
Biswa Panda committed
629
630
631
632
633
634
                }
            }
        });

        ResponseStream::new(Box::pin(stream), context)
    }
635
636
637
638
639
640
641
642
643
644
645

    /// Transform engine embedding output stream to OpenAI embedding response stream
    pub fn transform_embedding_postprocessor_stream(
        stream: ManyOut<Annotated<EmbeddingsEngineOutput>>,
        original_request: NvCreateEmbeddingRequest,
    ) -> ManyOut<Annotated<NvCreateEmbeddingResponse>> {
        let context = stream.context();

        let transformed_stream = stream.map(move |output| {
            output.map_data(|engine_output| {
                // Convert engine output to OpenAI response format
646
                let embeddings: Vec<dynamo_async_openai::types::Embedding> = engine_output
647
648
649
                    .embeddings
                    .into_iter()
                    .enumerate()
650
                    .map(|(index, embedding)| dynamo_async_openai::types::Embedding {
651
652
653
654
655
656
657
                        index: index as u32,
                        object: "embedding".to_string(),
                        embedding: embedding.into_iter().map(|f| f as f32).collect(),
                    })
                    .collect();

                let response = NvCreateEmbeddingResponse {
658
                    inner: dynamo_async_openai::types::CreateEmbeddingResponse {
659
660
661
                        object: "list".to_string(),
                        model: original_request.inner.model.clone(),
                        data: embeddings,
662
                        usage: dynamo_async_openai::types::EmbeddingUsage {
663
664
665
666
667
668
669
670
671
672
673
674
                            prompt_tokens: engine_output.prompt_tokens,
                            total_tokens: engine_output.total_tokens,
                        },
                    },
                };

                Ok(response)
            })
        });

        ResponseStream::new(Box::pin(transformed_stream), context)
    }
675
676

    /// Apply tool calling jail to the stream using the preprocessor's tool call parser
677
    pub async fn apply_tool_calling_jail_with_parser(
678
679
680
        &self,
        stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    ) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
681
        apply_tool_calling_jail_internal(stream, self.tool_call_parser.clone()).await
682
683
684
685
686
    }
}

/// Apply tool calling jail to the stream - stops/jails the stream under certain conditions
/// When jailed, the stream will be unjailed when the input stream ends
687
pub async fn apply_tool_calling_jail_internal(
688
689
690
691
692
693
694
695
696
697
698
699
700
    stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    tool_call_parser: Option<String>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
    let context = stream.context();

    let jail_state = JailState {
        stream,
        is_jailed: false,
        tool_call_parser,
        accumulated_content: HashMap::new(),
        last_response_metadata: None,
        finished: false,
    };
701

702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
    // Transform the stream using unfold to maintain state
    // Input: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>
    // Returns None if the stream is finished
    // Returns Some((Annotated<NvCreateChatCompletionStreamResponse>, JailState)) if the stream is not finished
    // End output: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>
    let jailed_stream = stream::unfold(jail_state, |mut state| async move {
        // If already finished, return None immediately
        if state.finished {
            return None;
        }

        if let Some(response) = state.stream.next().await {
            // Check if we should jail the stream
            if !state.is_jailed {
                // Handle the case where response.data is Option<T>
                if let Some(ref chat_response) = response.data {
                    // Store metadata for potential tool call parsing later
                    state.last_response_metadata = Some(chat_response.clone());

                    // Extract text content from the response
                    if let Some(choice) = chat_response.choices.first()
                        && let Some(ref content) = choice.delta.content
                    {
                        // Check for tool call start
                        match detect_tool_call_start(content, state.tool_call_parser.as_deref()) {
                            Ok(should_jail) => {
                                if should_jail {
                                    tracing::debug!("Tool call detected, jailing stream");
                                    state.is_jailed = true;

                                    // Start accumulating content for this choice
                                    state
                                        .accumulated_content
                                        .insert(choice.index, content.clone());

                                    // Create possible tool call annotation with token information
                                    let possible_annotation = PossibleToolCallAnnotation {
                                        possible_tokens: 1, // This chunk contains tokens being processed
                                        possible_content: content.clone(),
                                        parser_used: state.tool_call_parser.clone(),
                                    };

                                    // Create annotated response instead of empty response
                                    let mut annotated_response = response.clone();
                                    if let Ok(possible_annotated) =
                                        possible_annotation
                                            .to_annotation::<NvCreateChatCompletionStreamResponse>()
                                    {
                                        // Set annotation event and comment
                                        annotated_response.event = possible_annotated.event;
                                        annotated_response.comment = possible_annotated.comment;
                                    }

                                    // Modify the response to have empty content but keep metadata
                                    annotated_response =
                                        annotated_response.map_data(|mut chat_response| {
                                            // Clear the content but keep choice structure for ITL measurement
                                            for choice in &mut chat_response.choices {
                                                choice.delta.content = Some(String::new()); // Empty content
                                            }
                                            Ok(chat_response)
                                        });

                                    return Some((annotated_response, state));
                                }
                            }
                            Err(e) => {
                                tracing::warn!("Error detecting tool call start: {}", e);
                            }
                        }
                    }
                }
            } else if state.is_jailed {
                // If already jailed, continue to jail but with annotations and accumulate content
                if let Some(ref chat_response) = response.data {
                    // Extract content for annotation and accumulation
                    for choice in &chat_response.choices {
                        if let Some(ref content) = choice.delta.content
                            && !content.is_empty()
                        {
                            // Accumulate content for this choice
                            state
                                .accumulated_content
                                .entry(choice.index)
                                .or_default()
                                .push_str(content);

                            // Create possible tool call annotation
                            let possible_annotation = PossibleToolCallAnnotation {
                                possible_tokens: 1,
                                possible_content: content.clone(),
                                parser_used: state.tool_call_parser.clone(),
                            };

                            // Create annotated response
                            let mut annotated_response = response.clone();
                            if let Ok(possible_annotated) = possible_annotation
                                .to_annotation::<NvCreateChatCompletionStreamResponse>(
                            ) {
                                annotated_response.event = possible_annotated.event;
                                annotated_response.comment = possible_annotated.comment;
                            }

                            // Clear content but keep structure
                            annotated_response =
                                annotated_response.map_data(|mut chat_response| {
                                    for choice in &mut chat_response.choices {
                                        choice.delta.content = Some(String::new());
                                    }
                                    Ok(chat_response)
                                });

                            return Some((annotated_response, state));
                        }
                    }
                }
            }

            // If not jailed or jailing condition not met, return the response as-is
            Some((response, state))
        } else {
            // Stream ended - if we were jailed, we should unjail now and parse tool calls
            if state.is_jailed {
                tracing::debug!("Stream ended, unjailing and parsing accumulated content");
                state.is_jailed = false;

                // Parse accumulated content for tool calls
                if !state.accumulated_content.is_empty()
                    && let Some(base_response) = state.last_response_metadata.take()
                {
                    // Try to parse tool calls from accumulated content for each choice
                    let mut final_response = base_response.clone();

                    for (choice_index, accumulated_text) in &state.accumulated_content {
                        if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
                            accumulated_text,
                            state.tool_call_parser.as_deref(),
839
840
841
                        )
                        .await
                        {
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
                            // Found tool calls, create a final response with them
                            tracing::debug!(
                                "Parsed {} tool calls from accumulated content",
                                tool_calls.len()
                            );
                            for tool_call in &tool_calls {
                                tracing::debug!(
                                    tool_call_id = %tool_call.id,
                                    function_name = %tool_call.function.name,
                                    arguments = %tool_call.function.arguments,
                                    "Parsed structured tool call from accumulated content in jail"
                                );
                            }

                            // Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallChunk for streaming
                            let tool_call_chunks: Vec<
                                dynamo_async_openai::types::ChatCompletionMessageToolCallChunk,
                            > = tool_calls
                                .into_iter()
                                .enumerate()
                                .map(|(idx, tool_call)| {
                                    dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
                                        index: idx as u32,
                                        id: Some(tool_call.id),
                                        r#type: Some(tool_call.r#type),
                                        function: Some(
                                            dynamo_async_openai::types::FunctionCallStream {
                                                name: Some(tool_call.function.name),
                                                arguments: Some(tool_call.function.arguments),
                                            },
                                        ),
                                    }
                                })
                                .collect();

                            // Create a choice with tool calls
                            #[allow(deprecated)]
                            let final_choice = dynamo_async_openai::types::ChatChoiceStream {
                                index: *choice_index,
                                delta:
                                    dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
                                        role: Some(dynamo_async_openai::types::Role::Assistant),
                                        content: normal_text.filter(|t| !t.is_empty()),
                                        tool_calls: Some(tool_call_chunks.clone()),
                                        function_call: None,
                                        refusal: None,
                                        reasoning_content: None,
                                    },
                                finish_reason: Some(
                                    dynamo_async_openai::types::FinishReason::ToolCalls,
                                ),
                                logprobs: None,
                            };

                            // Update the response choices
                            final_response.choices = vec![final_choice];

                            // Create final annotated response
                            let final_annotated = Annotated {
                                data: Some(final_response),
                                id: None,
                                event: None,
                                comment: None,
                            };

                            state.finished = true; // Mark as finished before returning
                            return Some((final_annotated, state));
                        }
                    }
                }
            }
            state.finished = true; // Mark as finished
            None
        }
    });

    // Jailed Stream contains empty content chunks with annotation event "possible_tool_call" whenever the stream is jailed
    // This is a bad UX for the user, as they have to see a lot of empty content chunks
    // Filter out the empty content chunks with annotation event "possible_tool_call"
    let filtered_stream = jailed_stream.filter(|annotated| {
        let keep = annotated.event.as_deref() != Some(ANNOTATION_POSSIBLE_TOOL_CALL);
        async move { keep }
    });

    ResponseStream::new(Box::pin(filtered_stream), context)
Biswa Panda's avatar
Biswa Panda committed
927
928
929
930
931
932
933
934
935
936
}

// For pals, we do not want to add the generation prompt to the formatted prompt.
// We also need to know whether the template supports this `add_generation_prompt` flag;
// any prompt template that does not support it should return an error.
// oob: we should update every prompt template that does not support this flag to support it.

#[async_trait]
impl
    Operator<
937
        SingleIn<NvCreateChatCompletionRequest>,
938
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
939
        SingleIn<PreprocessedRequest>,
Biswa Panda's avatar
Biswa Panda committed
940
941
942
943
944
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
945
        request: SingleIn<NvCreateChatCompletionRequest>,
Biswa Panda's avatar
Biswa Panda committed
946
        next: Arc<
947
            dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
Biswa Panda's avatar
Biswa Panda committed
948
        >,
949
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
Biswa Panda's avatar
Biswa Panda committed
950
951
952
953
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
954
        let response_generator = request.response_generator(context.id().to_string());
Biswa Panda's avatar
Biswa Panda committed
955
956
        let mut response_generator = Box::new(response_generator);

957
958
        // set the runtime configuration
        response_generator.set_reasoning_parser(self.runtime_config.clone());
959
960
        let enable_tool_calling =
            maybe_enable_tool_call(self.tool_call_parser.as_deref(), &request);
Biswa Panda's avatar
Biswa Panda committed
961
962
963
964
        // convert the chat completion request to a common completion request
        let (common_request, annotations) = self.preprocess_request(&request)?;

        // update isl
Paul Hendricks's avatar
Paul Hendricks committed
965
        response_generator.update_isl(common_request.token_ids.len() as u32);
Biswa Panda's avatar
Biswa Panda committed
966
967
968
969
970

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations this will be prepend to the response stream
971
        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
Biswa Panda's avatar
Biswa Panda committed
972
973
974
975
976
977
978
979
980
981
982
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);

983
984
985
986
987
988
989
        // Apply tool calling jail to the stream if tool call parser is present
        let stream = if enable_tool_calling {
            self.apply_tool_calling_jail_with_parser(stream).await
        } else {
            stream
        };

990
        let context = stream.context();
Biswa Panda's avatar
Biswa Panda committed
991
992
993
994
995
996
997
998
999
1000
1001
        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}

#[async_trait]
impl
    Operator<
1002
        SingleIn<NvCreateCompletionRequest>,
1003
        ManyOut<Annotated<NvCreateCompletionResponse>>,
1004
        SingleIn<PreprocessedRequest>,
Biswa Panda's avatar
Biswa Panda committed
1005
1006
1007
1008
1009
        ManyOut<Annotated<BackendOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
1010
        request: SingleIn<NvCreateCompletionRequest>,
Biswa Panda's avatar
Biswa Panda committed
1011
        next: Arc<
1012
            dyn AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<BackendOutput>>, Error>,
Biswa Panda's avatar
Biswa Panda committed
1013
        >,
1014
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
Biswa Panda's avatar
Biswa Panda committed
1015
1016
1017
1018
        // unpack the request
        let (request, context) = request.into_parts();

        // create a response generator
1019
        let response_generator = request.response_generator(context.id().to_string());
Biswa Panda's avatar
Biswa Panda committed
1020
1021
        let mut response_generator = Box::new(response_generator);
        // convert the chat completion request to a common completion request
1022
1023
1024
        let mut builder = self.builder(&request)?;
        let annotations = self.gather_tokens(&request, &mut builder, None)?;
        let common_request = builder.build()?;
Biswa Panda's avatar
Biswa Panda committed
1025
1026

        // update isl
1027
        response_generator.update_isl(common_request.token_ids.len() as u32);
Biswa Panda's avatar
Biswa Panda committed
1028
1029
1030
1031
1032

        // repack the common completion request
        let common_request = context.map(|_| common_request);

        // create a stream of annotations this will be prepend to the response stream
1033
        let annotations: Vec<Annotated<NvCreateCompletionResponse>> = annotations
Biswa Panda's avatar
Biswa Panda committed
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
            .into_iter()
            .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
            .collect();
        let annotations_stream = stream::iter(annotations);

        // forward the common completion request to the next operator
        let response_stream = next.generate(common_request).await?;

        // transform the postprocessor stream
        let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
        let context = stream.context();

        // prepend the annotations to the response stream
        let stream = annotations_stream.chain(stream);

        // return the response stream
        Ok(ResponseStream::new(Box::pin(stream), context))
    }
}
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067

#[async_trait]
impl
    Operator<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        SingleIn<PreprocessedEmbeddingRequest>,
        ManyOut<Annotated<EmbeddingsEngineOutput>>,
    > for OpenAIPreprocessor
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateEmbeddingRequest>,
        next: Arc<
            dyn AsyncEngine<
1068
1069
1070
1071
                    SingleIn<PreprocessedEmbeddingRequest>,
                    ManyOut<Annotated<EmbeddingsEngineOutput>>,
                    Error,
                >,
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
        >,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        // Unpack request
        let (request, context) = request.into_parts();

        // Preprocess the embedding request
        let (preprocessed_request, annotations) =
            self.preprocess_embedding_request(&request).await?;

        // Forward to next stage
        let preprocessed_request = context.map(|_| preprocessed_request);
        let response_stream = next.generate(preprocessed_request).await?;

        // Transform response stream back to OpenAI format
        let stream = Self::transform_embedding_postprocessor_stream(response_stream, request);
        let context = stream.context();

        // Prepend annotations
        let annotations_stream = stream::iter(
            annotations
                .into_iter()
                .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
                .collect::<Vec<_>>(),
        );

        let combined_stream = annotations_stream.chain(stream);
        Ok(ResponseStream::new(Box::pin(combined_stream), context))
    }
}