"tests/kernels/attention/test_mha_attn.py" did not exist on "aa2cd2c43d1d19ece0f3b36ad716c3a9b8a2def0"
delta.rs 21.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
5
use std::sync::Arc;

6
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
Greg Clark's avatar
Greg Clark committed
7
use crate::{
8
    local_model::runtime_config::ModelRuntimeConfig,
9
    protocols::{
10
11
        common::{self, timing::RequestTracker},
        openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo},
12
    },
Greg Clark's avatar
Greg Clark committed
13
14
    types::TokenIdType,
};
15

16
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
17
impl NvCreateChatCompletionRequest {
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming chat completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
33
                        continuous_usage_stats: false,
34
35
36
37
38
39
40
41
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

42
43
    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
    ///
44
45
46
    /// # Arguments
    /// * `request_id` - The request ID to use for the chat completion response ID.
    ///
47
48
    /// # Returns
    /// * [`DeltaGenerator`] configured with model name and response options.
49
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
50
51
52
53
        // Enable tracking if:
        // 1. Client requested timing in extra_fields, OR
        // 2. query_instance_id annotation is present (needs worker_id tracking for response)
        let enable_tracking = self
54
            .nvext()
55
56
57
58
59
60
61
62
63
            .map(|nv| {
                nv.extra_fields
                    .as_ref()
                    .is_some_and(|fields| fields.iter().any(|f| f == "timing"))
                    || nv.annotations.as_ref().is_some_and(|annots| {
                        annots.iter().any(|a| a.starts_with("query_instance_id"))
                    })
            })
            .unwrap_or(false);
64

65
        let options = DeltaGeneratorOptions {
66
67
68
69
70
71
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
72
73
74
75
76
77
            continuous_usage_stats: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.continuous_usage_stats)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
78
79
            enable_logprobs: self.inner.logprobs.unwrap_or(false)
                || self.inner.top_logprobs.unwrap_or(0) > 0,
80
            enable_tracking,
81
            runtime_config: ModelRuntimeConfig::default(),
82
83
        };

84
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
85
86
87
    }
}

88
/// Configuration options for the [`DeltaGenerator`], controlling response behavior.
89
90
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
91
    /// Determines whether token usage statistics should be included in the response.
92
    pub enable_usage: bool,
93
94
    /// Determines whether continuous usage statistics should be included in the response.
    pub continuous_usage_stats: bool,
95
    /// Determines whether log probabilities should be included in the response.
96
    pub enable_logprobs: bool,
97
98
    /// Determines whether request tracking (timing, KV hit rate) should be enabled.
    pub enable_tracking: bool,
99

100
    pub runtime_config: ModelRuntimeConfig,
101
102
}

103
/// Generates incremental chat completion responses in a streaming fashion.
104
pub struct DeltaGenerator {
105
    /// Unique identifier for the chat completion session.
106
    id: String,
107
    /// Object type, representing a streamed chat completion response.
108
    object: String,
109
    /// Timestamp (Unix epoch) when the response was created.
Paul Hendricks's avatar
Paul Hendricks committed
110
    created: u32,
111
    model: String,
112
    /// Optional system fingerprint for version tracking.
113
    system_fingerprint: Option<String>,
114
    /// Optional service tier information for the response.
115
    service_tier: Option<dynamo_async_openai::types::ServiceTierResponse>,
116
    /// Tracks token usage for the completion request.
117
    usage: dynamo_async_openai::types::CompletionUsage,
118
    /// Counter tracking the number of messages issued.
119
    msg_counter: u64,
120
    /// Configuration options for response generation.
121
    options: DeltaGeneratorOptions,
122
123
    /// Optional request tracker for per-request metrics (shared with PreprocessedRequest).
    tracker: Option<Arc<RequestTracker>>,
124
125
126
}

impl DeltaGenerator {
127
128
129
130
131
    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
    ///
    /// # Arguments
    /// * `model` - The model name used for response generation.
    /// * `options` - Configuration options for enabling usage and log probabilities.
132
    /// * `request_id` - The request ID to use for the chat completion response.
133
134
135
    ///
    /// # Returns
    /// * A new instance of [`DeltaGenerator`].
136
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
137
138
139
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
140
141
142
143
144
            .as_secs();

        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
Paul Hendricks's avatar
Paul Hendricks committed
145

146
        let usage = dynamo_async_openai::types::CompletionUsage {
Paul Hendricks's avatar
Paul Hendricks committed
147
148
149
150
151
152
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
153

154
155
        let chatcmpl_id = format!("chatcmpl-{request_id}");

156
157
158
        // Create request tracker if tracking is enabled
        let tracker = if options.enable_tracking {
            Some(Arc::new(RequestTracker::new()))
159
160
161
162
        } else {
            None
        };

163
        Self {
164
            id: chatcmpl_id,
165
166
167
168
169
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
170
            usage,
171
172
            msg_counter: 0,
            options,
173
            tracker,
174
175
176
        }
    }

177
178
179
180
181
    /// Returns the request tracker if tracking is enabled, for sharing with PreprocessedRequest.
    pub fn tracker(&self) -> Option<Arc<RequestTracker>> {
        self.tracker.clone()
    }

182
183
184
185
    /// Updates the prompt token usage count.
    ///
    /// # Arguments
    /// * `isl` - The number of prompt tokens used.
Paul Hendricks's avatar
Paul Hendricks committed
186
    pub fn update_isl(&mut self, isl: u32) {
187
188
189
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
190
191
192
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
193
        token_ids: &[TokenIdType],
Greg Clark's avatar
Greg Clark committed
194
195
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
196
    ) -> Option<dynamo_async_openai::types::ChatChoiceLogprobs> {
Greg Clark's avatar
Greg Clark committed
197
198
199
200
201
202
203
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
204
            .map(|(token, token_id)| (token.unwrap_or_default(), *token_id))
Greg Clark's avatar
Greg Clark committed
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let content = top_logprobs.map(|top_logprobs| {
            toks.iter()
                .zip(tok_lps)
                .zip(top_logprobs)
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
224
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
225
226
227
228
229
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
230
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
231
232
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
233
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
234
235
236
237
238
                            token: t.clone(),
                            logprob: lp,
                            bytes: None,
                        });
                    }
239
                    dynamo_async_openai::types::ChatCompletionTokenLogprob {
Greg Clark's avatar
Greg Clark committed
240
241
242
243
244
245
246
247
248
                        token: t.clone(),
                        logprob: lp,
                        bytes: None,
                        top_logprobs: converted_top_lps,
                    }
                })
                .collect()
        });

249
        Some(dynamo_async_openai::types::ChatChoiceLogprobs {
Greg Clark's avatar
Greg Clark committed
250
251
252
253
254
            content,
            refusal: None,
        })
    }

255
256
257
258
259
260
261
    /// Creates a choice within a chat completion response.
    ///
    /// # Arguments
    /// * `index` - The index of the choice in the completion response.
    /// * `text` - The text content for the response.
    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
    /// * `logprobs` - Optional log probabilities of the generated tokens.
262
    /// * `stop_reason` - Optional stop string or token that triggered the stop.
263
264
    ///
    /// # Returns
265
    /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
Paul Hendricks's avatar
Paul Hendricks committed
266
    #[allow(deprecated)]
267
    pub fn create_choice(
268
        &mut self,
Paul Hendricks's avatar
Paul Hendricks committed
269
        index: u32,
270
        text: Option<String>,
271
272
        finish_reason: Option<dynamo_async_openai::types::FinishReason>,
        logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
273
        stop_reason: Option<dynamo_async_openai::types::StopReason>,
274
    ) -> NvCreateChatCompletionStreamResponse {
275
        let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
276
            content: text,
277
278
            function_call: None,
            tool_calls: None,
279
            role: if self.msg_counter == 0 {
280
                Some(dynamo_async_openai::types::Role::Assistant)
281
282
283
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
284
            refusal: None,
285
            reasoning_content: None,
286
287
        };

288
        let choice = dynamo_async_openai::types::ChatChoiceStream {
Paul Hendricks's avatar
Paul Hendricks committed
289
290
291
            index,
            delta,
            finish_reason,
292
            stop_reason,
Paul Hendricks's avatar
Paul Hendricks committed
293
294
295
296
297
            logprobs,
        };

        let choices = vec![choice];

298
299
300
301
302
303
304
305
306
307
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices,
308
309
310
311
312
            usage: if self.options.enable_usage && self.options.continuous_usage_stats {
                Some(self.get_usage())
            } else {
                None
            },
313
            service_tier: self.service_tier.clone(),
314
            nvext: None, // Will be populated by router layer if needed
315
        }
316
317
318
319
320
321
322
323
    }

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`CreateChatCompletionStreamResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
324
        let usage = self.get_usage();
325

326
        dynamo_async_openai::types::CreateChatCompletionStreamResponse {
327
328
329
330
331
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
332
333
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
334
            service_tier: self.service_tier.clone(),
335
            nvext: None,
336
337
        }
    }
338
339
340
341
342

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
343

344
345
346
347
348
    /// Check if continuous usage tracking is enabled
    pub fn is_continuous_usage_enabled(&self) -> bool {
        self.options.continuous_usage_stats
    }

349
350
351
352
353
    pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
        usage
    }
354
355
}

356
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
357
/// it to transform backend responses into OpenAI-style streaming responses.
358
359
360
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
361
362
363
364
365
366
367
368
    /// Converts a backend response into a structured OpenAI-style streaming response.
    ///
    /// # Arguments
    /// * `delta` - The backend response containing generated text and metadata.
    ///
    /// # Returns
    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
    /// * `Err(anyhow::Error)` if an error occurs.
369
370
371
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
372
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
        // Aggregate token usage even if usage tracking is disabled for metrics tracking
        // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until context lengths exceed 4_294_967_295.
        let token_length: u32 = delta
            .token_ids
            .len()
            .try_into()
            .expect("token_ids length exceeds u32::MAX");

        self.usage.completion_tokens += token_length;

        // If backend provides completion_usage with prompt token details,
        // propagate the entire details struct to usage tracking
        if let Some(prompt_details) = delta
            .completion_usage
            .as_ref()
            .and_then(|usage| usage.prompt_tokens_details.as_ref())
        {
            self.usage.prompt_tokens_details = Some(prompt_details.clone());
392
393
        }

Greg Clark's avatar
Greg Clark committed
394
395
        let logprobs = self.create_logprobs(
            delta.tokens,
396
            &delta.token_ids,
Greg Clark's avatar
Greg Clark committed
397
398
399
            delta.log_probs,
            delta.top_logprobs,
        );
400

401
        // Map backend finish reasons to OpenAI's finish reasons.
402
        let finish_reason = match delta.finish_reason {
403
404
405
406
407
408
409
410
411
412
            Some(common::FinishReason::EoS) => Some(dynamo_async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
            Some(common::FinishReason::Length) => {
                Some(dynamo_async_openai::types::FinishReason::Length)
            }
            Some(common::FinishReason::Cancelled) => {
                Some(dynamo_async_openai::types::FinishReason::Stop)
            }
413
            Some(common::FinishReason::ContentFilter) => {
414
                Some(dynamo_async_openai::types::FinishReason::ContentFilter)
415
            }
416
417
418
419
420
421
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

422
        // Create the streaming response.
423
        let index = 0;
424
425
426
427
428
429
430
        let mut stream_response = self.create_choice(
            index,
            delta.text,
            finish_reason,
            logprobs,
            delta.stop_reason,
        );
431

432
        // Record first token time (only succeeds on first call due to OnceLock)
433
        if let Some(ref tracker) = self.tracker {
434
435
436
            tracker.record_first_token();
        }

437
438
        // Get worker_id info from tracker (set by KvPushRouter based on phase)
        let worker_id_info = self.tracker.as_ref().and_then(|t| t.get_worker_info());
439

440
441
442
443
444
445
        let token_ids = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("token_ids"))
            .and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());

446
447
        // Get timing info if this is the final response (has finish_reason)
        let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
448
            self.tracker.as_ref().map(|tracker| {
449
450
451
452
453
454
455
                tracker.record_finish();
                tracker.get_timing_info()
            })
        } else {
            None
        };

456
457
        // Inject nvext if we have worker_id, token_ids, or timing
        if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
458
            let nvext_response = NvExtResponse {
459
460
                worker_id: worker_id_info.clone(),
                timing: timing_info,
461
                token_ids: token_ids.clone(),
462
463
464
465
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                stream_response.nvext = Some(nvext_json);
466
467
468
469
470
471
472
                if let Some(ref info) = worker_id_info {
                    tracing::debug!(
                        "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}",
                        info.prefill_worker_id,
                        info.decode_worker_id
                    );
                }
473
474
475
476
477
478
                if let Some(ref tokens) = token_ids {
                    tracing::debug!(
                        "Injected token_ids into chat completion nvext: {} tokens",
                        tokens.len()
                    );
                }
479
480
            }
        }
Paul Hendricks's avatar
Paul Hendricks committed
481

482
        Ok(stream_response)
483
    }
484
485
486
487

    fn get_isl(&self) -> Option<u32> {
        Some(self.usage.prompt_tokens)
    }
488
489
490
491
492
493
494
495

    fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
496

497
498
499
500
    fn is_continuous_usage_enabled(&self) -> bool {
        DeltaGenerator::is_continuous_usage_enabled(self)
    }

501
502
503
    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
        DeltaGenerator::get_usage(self)
    }
504
}
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532

#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
        CreateChatCompletionRequest,
    };

    /// Builds a minimal non-streaming request with no `stream_options` and no `nvext`.
    fn create_test_request() -> NvCreateChatCompletionRequest {
        let messages = vec![ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessage {
                content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
                name: None,
            },
        )];

        NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages,
                stream: Some(false),
                stream_options: None,
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
            chat_template_args: None,
            media_io_kwargs: None,
            unsupported_fields: Default::default(),
        }
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_enables_usage() {
        // Test that non-streaming requests get usage enabled
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(false); // false = non-streaming

        // Borrow instead of unwrapping by value so the options can be inspected twice.
        let opts = request
            .inner
            .stream_options
            .as_ref()
            .expect("Non-streaming request should have stream_options created");
        assert!(
            opts.include_usage,
            "Non-streaming request should have include_usage=true for OpenAI compliance"
        );
        assert!(
            !opts.continuous_usage_stats,
            "Non-streaming request should default to continuous_usage_stats=false"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_updates_existing_options() {
        // Existing stream_options get include_usage forced on; other fields are preserved.
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
            continuous_usage_stats: true,
        });

        request.enable_usage_for_nonstreaming(false);

        let opts = request
            .inner
            .stream_options
            .as_ref()
            .expect("Existing stream_options should be kept");
        assert!(
            opts.include_usage,
            "Non-streaming request should have include_usage=true for OpenAI compliance"
        );
        assert!(
            opts.continuous_usage_stats,
            "Existing continuous_usage_stats should be preserved"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_ignores_streaming() {
        // Test that streaming requests are not modified
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(true); // true = streaming

        assert!(
            request.inner.stream_options.is_none(),
            "Streaming request should not have stream_options modified"
        );
    }

    #[test]
    fn test_response_generator_reads_stream_options() {
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: true,
            continuous_usage_stats: true,
        });

        let generator = request.response_generator("req-1".to_string());

        assert!(generator.is_usage_enabled());
        assert!(generator.is_continuous_usage_enabled());
    }

    #[test]
    fn test_usage_chunk_reports_prompt_tokens() {
        let request = create_test_request();
        let mut generator = request.response_generator("req-1".to_string());
        generator.update_isl(5);

        let chunk = generator.create_usage_chunk();

        assert_eq!(chunk.id, "chatcmpl-req-1");
        assert!(chunk.choices.is_empty(), "Usage chunk must have empty choices");
        let usage = chunk.usage.expect("Usage chunk must carry usage");
        assert_eq!(usage.prompt_tokens, 5);
        assert_eq!(usage.completion_tokens, 0);
        assert_eq!(usage.total_tokens, 5);
    }
}