// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use anyhow::Result;
use serde::{Deserialize, Serialize};

use super::{
    ContentProvider,
9
    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
10
};
11
use crate::protocols::openai::common_ext::CommonExtProvider;
12
use crate::types::TokenIdType;
13

14
pub mod audios;
15
pub mod chat_completions;
16
pub mod common_ext;
17
18
pub mod completions;
pub mod embeddings;
19
pub mod images;
20
21
pub mod models;
pub mod nvext;
22
pub mod responses;
23
pub mod tools;
24
pub mod validate;
25
pub mod videos;
26

27
use validate::{
28
29
    BEST_OF_RANGE, FREQUENCY_PENALTY_RANGE, MIN_P_RANGE, N_RANGE, PRESENCE_PENALTY_RANGE,
    TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
30
};
31
32
33
34
35
36
37
38
39

/// A response delta of type `R` bundled with optional annotation metadata.
///
/// NOTE(review): the `id` / `event` / `comment` fields look like they mirror
/// Server-Sent Events framing — confirm against the code that emits them.
#[derive(Serialize, Deserialize, Debug)]
pub struct AnnotatedDelta<R> {
    /// The payload carried by this delta.
    pub delta: R,
    /// Optional identifier attached to the delta.
    pub id: Option<String>,
    /// Optional event name attached to the delta.
    pub event: Option<String>,
    /// Optional free-form comment attached to the delta.
    pub comment: Option<String>,
}

40
pub(crate) trait OpenAISamplingOptionsProvider {
41
42
43
44
45
46
47
48
    fn get_temperature(&self) -> Option<f32>;

    fn get_top_p(&self) -> Option<f32>;

    fn get_frequency_penalty(&self) -> Option<f32>;

    fn get_presence_penalty(&self) -> Option<f32>;

49
50
51
52
53
54
    fn get_seed(&self) -> Option<i64>;

    fn get_n(&self) -> Option<u8>;

    fn get_best_of(&self) -> Option<u8>;

55
56
57
    fn nvext(&self) -> Option<&nvext::NvExt>;
}

58
pub(crate) trait OpenAIStopConditionsProvider {
Paul Hendricks's avatar
Paul Hendricks committed
59
    fn get_max_tokens(&self) -> Option<u32>;
60

Paul Hendricks's avatar
Paul Hendricks committed
61
    fn get_min_tokens(&self) -> Option<u32>;
62
63
64
65

    fn get_stop(&self) -> Option<Vec<String>>;

    fn nvext(&self) -> Option<&nvext::NvExt>;
66
67
68
69
70
71
72

    /// Get ignore_eos from CommonExt if the type supports it.
    /// Default returns None for types without CommonExt support.
    fn get_common_ignore_eos(&self) -> Option<bool> {
        None
    }

73
    /// Get the effective ignore_eos value from CommonExt.
74
    fn get_ignore_eos(&self) -> Option<bool> {
75
        self.get_common_ignore_eos()
76
    }
77
78
79
80
81
82

    /// Get max_thinking_tokens from nvext
    /// NOTE: This is currently a passthrough for future thinking budget implementation
    fn get_max_thinking_tokens(&self) -> Option<u32> {
        self.nvext().and_then(|nv| nv.max_thinking_tokens)
    }
83
84
}

85
/// Accessors for the OpenAI-style output options of a request.
///
/// Every implementor receives the blanket `OutputOptionsProvider` implementation
/// in this module, which copies these values into `common::OutputOptions`.
pub(crate) trait OpenAIOutputOptionsProvider {
    /// Value forwarded to `common::OutputOptions::logprobs`.
    fn get_logprobs(&self) -> Option<u32>;

    /// Value forwarded to `common::OutputOptions::prompt_logprobs`.
    fn get_prompt_logprobs(&self) -> Option<u32>;

    /// Value forwarded to `common::OutputOptions::skip_special_tokens`.
    fn get_skip_special_tokens(&self) -> Option<bool>;

    /// Value forwarded to `common::OutputOptions::formatted_prompt`.
    fn get_formatted_prompt(&self) -> Option<bool>;
}

95
impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T {
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
    fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
        // let result = self.validate();
        // if let Err(e) = result {
        //     return Err(format!("Error validating sampling options: {}", e));
        // }

        let mut temperature = validate_range(self.get_temperature(), &TEMPERATURE_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating temperature: {}", e))?;
        let mut top_p = validate_range(self.get_top_p(), &TOP_P_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating top_p: {}", e))?;
        let frequency_penalty =
            validate_range(self.get_frequency_penalty(), &FREQUENCY_PENALTY_RANGE)
                .map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
        let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
111
112
        let top_k = CommonExtProvider::get_top_k(self);
        let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
113
        let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self);
114
115
116
117
118
119
120
121
        let seed = self.get_seed();
        let n = validate_range(self.get_n(), &N_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating n: {}", e))?;
        let best_of = validate_range(self.get_best_of(), &BEST_OF_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating best_of: {}", e))?;

        let min_p = validate_range(CommonExtProvider::get_min_p(self), &MIN_P_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating min_p: {}", e))?;
122
123
124
125
126
127
128
129
130

        if let Some(nvext) = self.nvext() {
            let greedy = nvext.greed_sampling.unwrap_or(false);
            if greedy {
                top_p = None;
                temperature = None;
            }
        }

131
132
133
134
135
        let guided_decoding_backend = self.get_guided_decoding_backend();
        let guided_json = self.get_guided_json();
        let guided_regex = self.get_guided_regex();
        let guided_grammar = self.get_guided_grammar();
        let guided_choice = self.get_guided_choice();
136
        let guided_whitespace_pattern = self.get_guided_whitespace_pattern();
137
138

        let guided_decoding = match common::GuidedDecodingOptions::from_optional(
139
            guided_json,
140
141
142
143
            guided_regex,
            guided_choice,
            guided_grammar,
            guided_decoding_backend,
144
            guided_whitespace_pattern,
145
146
147
148
149
150
        ) {
            Ok(options) => options,
            Err(e) => {
                // Handle the validation error (log, return error, etc.)
                tracing::error!("Invalid guided decoding options: {:?}", e);
                return Err(e);
151
            }
152
        };
153

154
        Ok(common::SamplingOptions {
155
156
            n,
            best_of,
157
158
            frequency_penalty,
            presence_penalty,
159
            repetition_penalty,
160
161
            temperature,
            top_p,
162
            top_k,
163
164
            min_p,
            seed,
165
166
            use_beam_search: None,
            length_penalty: None,
167
            guided_decoding,
168
            include_stop_str_in_output,
169
170
171
172
173
174
        })
    }
}

impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
    fn extract_stop_conditions(&self) -> Result<common::StopConditions> {
Paul Hendricks's avatar
Paul Hendricks committed
175
        let max_tokens = self.get_max_tokens();
176
177
        let min_tokens = self.get_min_tokens();
        let stop = self.get_stop();
178
        let max_thinking_tokens = self.get_max_thinking_tokens();
179

180
181
182
183
        if let Some(stop) = &stop
            && stop.len() > 4
        {
            anyhow::bail!("stop conditions must be less than 4")
184
185
        }

186
187
        // Use the trait method to get ignore_eos, which handles precedence
        let ignore_eos = self.get_ignore_eos();
188
189
190

        Ok(common::StopConditions {
            max_tokens,
Paul Hendricks's avatar
Paul Hendricks committed
191
            min_tokens,
192
193
194
            stop,
            stop_token_ids_hidden: None,
            ignore_eos,
195
            max_thinking_tokens,
196
197
198
199
        })
    }
}

Greg Clark's avatar
Greg Clark committed
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T {
    /// Copies the request's output-related settings into [`common::OutputOptions`].
    ///
    /// Each field is taken verbatim from the matching getter; this never fails.
    fn extract_output_options(&self) -> Result<common::OutputOptions> {
        Ok(common::OutputOptions {
            logprobs: self.get_logprobs(),
            prompt_logprobs: self.get_prompt_logprobs(),
            skip_special_tokens: self.get_skip_special_tokens(),
            formatted_prompt: self.get_formatted_prompt(),
        })
    }
}

216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
/// Encodes a token string as its UTF-8 bytes for OpenAI logprobs responses.
///
/// An empty token (an unknown/unresolved token from the backend) yields `None`;
/// any other token yields `Some` holding a copy of its bytes.
pub(crate) fn token_to_utf8_bytes(token: &str) -> Option<Vec<u8>> {
    let raw = token.as_bytes();
    (!raw.is_empty()).then(|| raw.to_vec())
}

/// Translates backend `TopLogprob` entries into OpenAI-compatible `TopLogprobs`.
///
/// Entries keep their original order. An entry without a token string gets an
/// empty token; an entry without explicit bytes gets the UTF-8 bytes of its token
/// (or `None` when that token is empty). When no entry carries
/// `selected_token_id`, the selected token is appended at the end so it is always
/// present in the result.
pub(crate) fn convert_backend_top_logprobs(
    top_lps: &[common::llm_backend::TopLogprob],
    selected_token: &str,
    selected_token_id: TokenIdType,
    selected_logprob: f32,
) -> Vec<dynamo_async_openai::types::TopLogprobs> {
    use dynamo_async_openai::types::TopLogprobs;

    // Membership is decided by token id alone, independent of the token text.
    let selected_is_listed = top_lps
        .iter()
        .any(|entry| entry.token_id == selected_token_id);

    let mut converted: Vec<TopLogprobs> = Vec::new();
    for entry in top_lps {
        let token = entry.token.clone().unwrap_or_default();
        let bytes = match &entry.bytes {
            Some(raw) => Some(raw.clone()),
            None => token_to_utf8_bytes(&token),
        };
        converted.push(TopLogprobs {
            token,
            logprob: entry.logprob as f32,
            bytes,
        });
    }

    if !selected_is_listed {
        converted.push(TopLogprobs {
            token: selected_token.to_string(),
            logprob: selected_logprob,
            bytes: token_to_utf8_bytes(selected_token),
        });
    }

    converted
}

259
260
pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
    Send + 'static
261
262
263
264
265
{
    fn choice_from_postprocessor(
        &mut self,
        response: common::llm_backend::BackendOutput,
    ) -> Result<ResponseType>;
266
267
268

    /// Gets the current prompt token count (Input Sequence Length).
    fn get_isl(&self) -> Option<u32>;
269
270
271
272
273
274

    /// Creates a final usage-only chunk for OpenAI compliance.
    fn create_usage_chunk(&self) -> ResponseType;

    /// Check if usage tracking is enabled.
    fn is_usage_enabled(&self) -> bool;
275

276
277
278
    /// Check if continuous usage tracking is enabled.
    fn is_continuous_usage_enabled(&self) -> bool;

279
280
    /// Get the current usage statistics with properly calculated total_tokens.
    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage;
281
282
283
284
285

    /// Returns the request tracker if available, for accessing worker timing metrics.
    fn tracker(&self) -> Option<std::sync::Arc<common::timing::RequestTracker>> {
        None
    }
286
}
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302

/// Optional parser selections; both fields default to `None`.
///
/// NOTE(review): the strings presumably name parser implementations resolved
/// elsewhere — confirm the accepted values against the consumers of this type.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct ParsingOptions {
    /// Tool-call parser selection, if any.
    pub tool_call_parser: Option<String>,

    /// Reasoning parser selection, if any.
    pub reasoning_parser: Option<String>,
}

impl ParsingOptions {
    /// Creates a `ParsingOptions` from the given parser selections.
    ///
    /// Both arguments are stored unchanged; pass `None` to leave a parser unset.
    pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
        Self {
            tool_call_parser,
            reasoning_parser,
        }
    }
}