delta.rs 11.3 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
5
6
7
8
9
10
11
use crate::{
    protocols::{
        common,
        openai::nvext::{NvExtResponse, WorkerIdInfo},
    },
    types::TokenIdType,
};
12

13
impl NvCreateCompletionRequest {
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

39
40
    // put this method on the request
    // inspect the request to extract options
41
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
42
        let options = DeltaGeneratorOptions {
43
44
45
46
47
48
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
49
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
50
51
        };

52
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
53
54
55
56
57
58
59
60
61
62
63
64
65
    }
}

/// Feature switches that control what a `DeltaGenerator` produces.
///
/// Both switches are off (`false`) by default.
#[derive(Clone, Debug, Default)]
pub struct DeltaGeneratorOptions {
    /// When set, completion token counts are accumulated from backend output.
    pub enable_usage: bool,
    /// When set, logprobs are converted and attached to generated choices.
    pub enable_logprobs: bool,
}

#[derive(Debug, Clone)]
pub struct DeltaGenerator {
    id: String,
    object: String,
66
    created: u32,
67
68
    model: String,
    system_fingerprint: Option<String>,
69
    usage: dynamo_async_openai::types::CompletionUsage,
70
71
72
73
    options: DeltaGeneratorOptions,
}

impl DeltaGenerator {
74
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
75
76
77
78
79
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

80
81
82
83
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

84
85
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
86
        let usage = dynamo_async_openai::types::CompletionUsage {
87
88
89
90
91
92
93
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

94
95
        let completion_id = format!("cmpl-{request_id}");

96
        Self {
97
            id: completion_id,
98
99
100
101
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
102
            usage,
103
104
105
106
            options,
        }
    }

107
    pub fn update_isl(&mut self, isl: u32) {
108
109
110
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
111
112
113
114
115
116
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
117
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
145
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
146
147
148
149
150
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
151
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
152
153
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
154
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
155
156
157
158
159
160
161
162
163
164
                            token: t.clone(),
                            logprob: *lp,
                            bytes: None,
                        });
                    }
                    serde_json::to_value(converted_top_lps).unwrap()
                })
                .collect()
        });

165
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
166
167
168
169
170
171
172
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

173
174
    pub fn create_choice(
        &self,
175
        index: u32,
176
        text: Option<String>,
177
178
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
179
    ) -> NvCreateCompletionResponse {
180
181
        // todo - update for tool calling

182
183
184
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
185
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
186
187
            id: self.id.clone(),
            object: self.object.clone(),
188
            created: self.created,
189
190
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
191
            choices: vec![dynamo_async_openai::types::Choice {
192
                text: text.unwrap_or_default(),
193
                index,
194
                finish_reason,
Greg Clark's avatar
Greg Clark committed
195
                logprobs,
196
            }],
197
            usage: None, // Always None for chunks with content/choices
198
            nvext: None, // Will be populated by router layer if needed
199
200
201
        };

        NvCreateCompletionResponse { inner }
202
    }
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
221
            nvext: None, // Will be populated by router layer if needed
222
223
224
225
226
227
228
229
230
        };

        NvCreateCompletionResponse { inner }
    }

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
231
232
}

233
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
234
235
236
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
237
    ) -> anyhow::Result<NvCreateCompletionResponse> {
238
239
        // aggregate usage
        if self.options.enable_usage {
240
241
242
243
244
245
246
247
248
            // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
            // but this will not be an issue until context lengths exceed 4_294_967_295.
            let token_length: u32 = delta
                .token_ids
                .len()
                .try_into()
                .expect("token_ids length exceeds u32::MAX");

            self.usage.completion_tokens += token_length;
249
250
251
252
253
254
255
256
257
258

            // If backend provides completion_usage with prompt token details,
            // propagate the entire details struct to usage tracking
            if let Some(prompt_details) = delta
                .completion_usage
                .as_ref()
                .and_then(|usage| usage.prompt_tokens_details.as_ref())
            {
                self.usage.prompt_tokens_details = Some(prompt_details.clone());
            }
259
260
        }

Greg Clark's avatar
Greg Clark committed
261
262
263
264
265
266
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
267
268

        let finish_reason = delta.finish_reason.map(Into::into);
269
270

        // create choice
271
        let index = delta.index.unwrap_or(0);
272
273
274
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

        // Extract worker_id from disaggregated_params and inject into nvext if present
275
        if let Some(worker_id_info) = delta
276
277
278
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("worker_id"))
279
            .and_then(|v| serde_json::from_value::<WorkerIdInfo>(v.clone()).ok())
280
281
        {
            let nvext_response = NvExtResponse {
282
                worker_id: Some(worker_id_info.clone()),
283
284
285
286
287
288
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                response.inner.nvext = Some(nvext_json);
                tracing::debug!(
                    "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
289
290
                    worker_id_info.prefill_worker_id,
                    worker_id_info.decode_worker_id
291
292
293
294
                );
            }
        }

295
        Ok(response)
296
    }
297
298

    fn get_isl(&self) -> Option<u32> {
299
        Some(self.usage.prompt_tokens)
300
    }
301
302
303
304
305
306
307
308

    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
309
}