delta.rs 11.6 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
Greg Clark's avatar
Greg Clark committed
5
use crate::{protocols::common, types::TokenIdType};
6

7
impl NvCreateCompletionRequest {
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
    /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
    ///
    /// According to OpenAI API spec, non-streaming completion responses (stream=false)
    /// must always include usage statistics. This method ensures `stream_options.include_usage`
    /// is set to `true` for non-streaming requests.
    ///
    /// Reference: https://platform.openai.com/docs/api-reference/completions/create
    ///
    /// # Arguments
    /// * `original_stream_flag` - The original value of the `stream` field before any internal processing
    pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
        if !original_stream_flag {
            // For non-streaming requests (stream=false), enable usage by default
            if self.inner.stream_options.is_none() {
                self.inner.stream_options =
                    Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
                        include_usage: true,
                    });
            } else if let Some(ref mut opts) = self.inner.stream_options {
                // If stream_options exists, ensure include_usage is true for non-streaming
                opts.include_usage = true;
            }
        }
    }

33
34
    // put this method on the request
    // inspect the request to extract options
35
    pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
36
        let options = DeltaGeneratorOptions {
37
38
39
40
41
42
            enable_usage: self
                .inner
                .stream_options
                .as_ref()
                .map(|opts| opts.include_usage)
                .unwrap_or(false),
Greg Clark's avatar
Greg Clark committed
43
            enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
44
45
        };

46
        DeltaGenerator::new(self.inner.model.clone(), options, request_id)
47
48
49
50
51
52
53
54
55
56
57
58
59
    }
}

/// Per-request switches that control what a [`DeltaGenerator`] produces.
///
/// The derived `Default` leaves both switches off.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    /// When `true`, the generator accumulates completion token counts as deltas arrive and
    /// reports usage as enabled through `is_usage_enabled`.
    pub enable_usage: bool,
    /// When `true`, `create_logprobs` may return log probabilities; when `false` it always
    /// returns `None`.
    pub enable_logprobs: bool,
}

#[derive(Debug, Clone)]
pub struct DeltaGenerator {
    id: String,
    object: String,
60
    created: u32,
61
62
    model: String,
    system_fingerprint: Option<String>,
63
    usage: dynamo_async_openai::types::CompletionUsage,
64
65
66
67
    options: DeltaGeneratorOptions,
}

impl DeltaGenerator {
68
    pub fn new(model: String, options: DeltaGeneratorOptions, request_id: String) -> Self {
69
70
71
72
73
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();

74
75
76
77
        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
        // but this will not be an issue until 2106.
        let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");

78
79
        // Previously, our home-rolled CompletionUsage impl'd Default
        // PR !387 - https://github.com/64bit/async-openai/pull/387
80
        let usage = dynamo_async_openai::types::CompletionUsage {
81
82
83
84
85
86
87
            completion_tokens: 0,
            prompt_tokens: 0,
            total_tokens: 0,
            completion_tokens_details: None,
            prompt_tokens_details: None,
        };

88
89
        let completion_id = format!("cmpl-{request_id}");

90
        Self {
91
            id: completion_id,
92
93
94
95
            object: "text_completion".to_string(),
            created: now,
            model,
            system_fingerprint: None,
96
            usage,
97
98
99
100
            options,
        }
    }

101
    pub fn update_isl(&mut self, isl: u32) {
102
103
104
        self.usage.prompt_tokens = isl;
    }

Greg Clark's avatar
Greg Clark committed
105
106
107
108
109
110
    pub fn create_logprobs(
        &self,
        tokens: Vec<common::llm_backend::TokenType>,
        token_ids: Vec<TokenIdType>,
        logprobs: Option<common::llm_backend::LogProbs>,
        top_logprobs: Option<common::llm_backend::TopLogprobs>,
111
    ) -> Option<dynamo_async_openai::types::Logprobs> {
Greg Clark's avatar
Greg Clark committed
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
        if !self.options.enable_logprobs || logprobs.is_none() {
            return None;
        }

        let toks = tokens
            .into_iter()
            .zip(token_ids)
            .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
            .collect::<Vec<(String, TokenIdType)>>();
        let tok_lps = toks
            .iter()
            .zip(logprobs.unwrap())
            .map(|(_, lp)| lp as f32)
            .collect::<Vec<f32>>();

        let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
            toks.iter()
                .zip(tok_lps.iter())
                .zip(top_logprobs.iter())
                .map(|(((t, tid), lp), top_lps)| {
                    let mut found_selected_token = false;
                    let mut converted_top_lps = top_lps
                        .iter()
                        .map(|top_lp| {
                            let top_t = top_lp.token.clone().unwrap_or_default();
                            let top_tid = top_lp.token_id;
                            found_selected_token = found_selected_token || top_tid == *tid;
139
                            dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
140
141
142
143
144
                                token: top_t,
                                logprob: top_lp.logprob as f32,
                                bytes: None,
                            }
                        })
145
                        .collect::<Vec<dynamo_async_openai::types::TopLogprobs>>();
Greg Clark's avatar
Greg Clark committed
146
147
                    if !found_selected_token {
                        // If the selected token is not in the top logprobs, add it
148
                        converted_top_lps.push(dynamo_async_openai::types::TopLogprobs {
Greg Clark's avatar
Greg Clark committed
149
150
151
152
153
154
155
156
157
158
                            token: t.clone(),
                            logprob: *lp,
                            bytes: None,
                        });
                    }
                    serde_json::to_value(converted_top_lps).unwrap()
                })
                .collect()
        });

159
        Some(dynamo_async_openai::types::Logprobs {
Greg Clark's avatar
Greg Clark committed
160
161
162
163
164
165
166
            tokens: toks.iter().map(|(t, _)| t.clone()).collect(),
            token_logprobs: tok_lps.into_iter().map(Some).collect(),
            text_offset: vec![],
            top_logprobs: top_lps,
        })
    }

167
168
    pub fn create_choice(
        &self,
169
        index: u32,
170
        text: Option<String>,
171
172
        finish_reason: Option<dynamo_async_openai::types::CompletionFinishReason>,
        logprobs: Option<dynamo_async_openai::types::Logprobs>,
173
    ) -> NvCreateCompletionResponse {
174
175
        // todo - update for tool calling

176
177
178
        // According to OpenAI spec: when stream_options.include_usage is true,
        // all intermediate chunks should have usage: null
        // The final usage chunk will be sent separately with empty choices
179
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
180
181
            id: self.id.clone(),
            object: self.object.clone(),
182
            created: self.created,
183
184
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
185
            choices: vec![dynamo_async_openai::types::Choice {
186
                text: text.unwrap_or_default(),
187
                index,
188
                finish_reason,
Greg Clark's avatar
Greg Clark committed
189
                logprobs,
190
            }],
191
            usage: None, // Always None for chunks with content/choices
192
            nvext: None, // Will be populated by router layer if needed
193
194
195
        };

        NvCreateCompletionResponse { inner }
196
    }
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214

    /// Creates a final usage-only chunk for OpenAI compliance.
    /// This should be sent after the last content chunk when stream_options.include_usage is true.
    ///
    /// # Returns
    /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
    pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        let mut usage = self.usage.clone();
        usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
            choices: vec![], // Empty choices for usage-only chunk
            usage: Some(usage),
215
            nvext: None, // Will be populated by router layer if needed
216
217
218
219
220
221
222
223
224
        };

        NvCreateCompletionResponse { inner }
    }

    /// Check if usage tracking is enabled
    pub fn is_usage_enabled(&self) -> bool {
        self.options.enable_usage
    }
225
226
}

227
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
228
229
230
    fn choice_from_postprocessor(
        &mut self,
        delta: common::llm_backend::BackendOutput,
231
    ) -> anyhow::Result<NvCreateCompletionResponse> {
232
233
        // aggregate usage
        if self.options.enable_usage {
234
235
236
237
238
239
240
241
242
            // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
            // but this will not be an issue until context lengths exceed 4_294_967_295.
            let token_length: u32 = delta
                .token_ids
                .len()
                .try_into()
                .expect("token_ids length exceeds u32::MAX");

            self.usage.completion_tokens += token_length;
243
244
245
246
247
248
249
250
251
252

            // If backend provides completion_usage with prompt token details,
            // propagate the entire details struct to usage tracking
            if let Some(prompt_details) = delta
                .completion_usage
                .as_ref()
                .and_then(|usage| usage.prompt_tokens_details.as_ref())
            {
                self.usage.prompt_tokens_details = Some(prompt_details.clone());
            }
253
254
        }

Greg Clark's avatar
Greg Clark committed
255
256
257
258
259
260
        let logprobs = self.create_logprobs(
            delta.tokens,
            delta.token_ids,
            delta.log_probs,
            delta.top_logprobs,
        );
261
262

        let finish_reason = delta.finish_reason.map(Into::into);
263
264

        // create choice
265
        let index = delta.index.unwrap_or(0);
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
        let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);

        // Extract worker_id from disaggregated_params and inject into nvext if present
        if let Some(worker_id_json) = delta
            .disaggregated_params
            .as_ref()
            .and_then(|params| params.get("worker_id"))
        {
            use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo};

            let prefill_worker_id = worker_id_json
                .get("prefill_worker_id")
                .and_then(|v| v.as_u64());
            let decode_worker_id = worker_id_json
                .get("decode_worker_id")
                .and_then(|v| v.as_u64());

            let worker_id_info = WorkerIdInfo {
                prefill_worker_id,
                decode_worker_id,
            };

            let nvext_response = NvExtResponse {
                worker_id: Some(worker_id_info),
            };

            if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
                response.inner.nvext = Some(nvext_json);
                tracing::debug!(
                    "Injected worker_id into completions nvext: prefill={:?}, decode={:?}",
                    prefill_worker_id,
                    decode_worker_id
                );
            }
        }

302
        Ok(response)
303
    }
304
305

    fn get_isl(&self) -> Option<u32> {
306
        Some(self.usage.prompt_tokens)
307
    }
308
309
310
311
312
313
314
315

    fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
        DeltaGenerator::create_usage_chunk(self)
    }

    fn is_usage_enabled(&self) -> bool {
        DeltaGenerator::is_usage_enabled(self)
    }
316
}