// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use anyhow::Result;
use serde::{Deserialize, Serialize};

use super::{
    ContentProvider,
9
    common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
10
};
11
use crate::protocols::openai::common_ext::CommonExtProvider;
12
use crate::types::TokenIdType;
13

14
pub mod chat_completions;
15
pub mod common_ext;
16
17
pub mod completions;
pub mod embeddings;
18
pub mod images;
19
20
pub mod models;
pub mod nvext;
21
pub mod responses;
22
pub mod tools;
23
pub mod validate;
24
pub mod videos;
25

26
use validate::{
27
28
    BEST_OF_RANGE, FREQUENCY_PENALTY_RANGE, MIN_P_RANGE, N_RANGE, PRESENCE_PENALTY_RANGE,
    TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
29
};
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

/// A delta payload of type `R` bundled with optional annotation fields.
///
/// NOTE(review): `id`, `event` and `comment` look like Server-Sent Events fields —
/// confirm against the code that produces and consumes this type.
#[derive(Serialize, Deserialize, Debug)]
pub struct AnnotatedDelta<R> {
    /// The wrapped delta payload.
    pub delta: R,
    /// Optional identifier attached to this delta.
    pub id: Option<String>,
    /// Optional event name attached to this delta.
    pub event: Option<String>,
    /// Optional free-form comment attached to this delta.
    pub comment: Option<String>,
}

trait OpenAISamplingOptionsProvider {
    fn get_temperature(&self) -> Option<f32>;

    fn get_top_p(&self) -> Option<f32>;

    fn get_frequency_penalty(&self) -> Option<f32>;

    fn get_presence_penalty(&self) -> Option<f32>;

48
49
50
51
52
53
    fn get_seed(&self) -> Option<i64>;

    fn get_n(&self) -> Option<u8>;

    fn get_best_of(&self) -> Option<u8>;

54
55
56
57
    fn nvext(&self) -> Option<&nvext::NvExt>;
}

trait OpenAIStopConditionsProvider {
Paul Hendricks's avatar
Paul Hendricks committed
58
    fn get_max_tokens(&self) -> Option<u32>;
59

Paul Hendricks's avatar
Paul Hendricks committed
60
    fn get_min_tokens(&self) -> Option<u32>;
61
62
63
64

    fn get_stop(&self) -> Option<Vec<String>>;

    fn nvext(&self) -> Option<&nvext::NvExt>;
65
66
67
68
69
70
71

    /// Get ignore_eos from CommonExt if the type supports it.
    /// Default returns None for types without CommonExt support.
    fn get_common_ignore_eos(&self) -> Option<bool> {
        None
    }

72
    /// Get the effective ignore_eos value from CommonExt.
73
    fn get_ignore_eos(&self) -> Option<bool> {
74
        self.get_common_ignore_eos()
75
    }
76
77
78
79
80
81

    /// Get max_thinking_tokens from nvext
    /// NOTE: This is currently a passthrough for future thinking budget implementation
    fn get_max_thinking_tokens(&self) -> Option<u32> {
        self.nvext().and_then(|nv| nv.max_thinking_tokens)
    }
82
83
}

Greg Clark's avatar
Greg Clark committed
84
85
86
87
88
89
90
91
92
93
/// Accessors for the output-related fields of an OpenAI-style request.
///
/// Every implementor receives `OutputOptionsProvider` through the blanket
/// implementation in this module, which copies these four values into
/// `common::OutputOptions` unchanged.
trait OpenAIOutputOptionsProvider {
    /// Returns the requested `logprobs` value, if any.
    fn get_logprobs(&self) -> Option<u32>;

    /// Returns the requested `prompt_logprobs` value, if any.
    fn get_prompt_logprobs(&self) -> Option<u32>;

    /// Returns the requested `skip_special_tokens` flag, if any.
    fn get_skip_special_tokens(&self) -> Option<bool>;

    /// Returns the requested `formatted_prompt` flag, if any.
    fn get_formatted_prompt(&self) -> Option<bool>;
}

94
impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T {
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
    fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
        // let result = self.validate();
        // if let Err(e) = result {
        //     return Err(format!("Error validating sampling options: {}", e));
        // }

        let mut temperature = validate_range(self.get_temperature(), &TEMPERATURE_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating temperature: {}", e))?;
        let mut top_p = validate_range(self.get_top_p(), &TOP_P_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating top_p: {}", e))?;
        let frequency_penalty =
            validate_range(self.get_frequency_penalty(), &FREQUENCY_PENALTY_RANGE)
                .map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
        let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
110
111
        let top_k = CommonExtProvider::get_top_k(self);
        let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
112
        let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self);
113
114
115
116
117
118
119
120
        let seed = self.get_seed();
        let n = validate_range(self.get_n(), &N_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating n: {}", e))?;
        let best_of = validate_range(self.get_best_of(), &BEST_OF_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating best_of: {}", e))?;

        let min_p = validate_range(CommonExtProvider::get_min_p(self), &MIN_P_RANGE)
            .map_err(|e| anyhow::anyhow!("Error validating min_p: {}", e))?;
121
122
123
124
125
126
127
128
129

        if let Some(nvext) = self.nvext() {
            let greedy = nvext.greed_sampling.unwrap_or(false);
            if greedy {
                top_p = None;
                temperature = None;
            }
        }

130
131
132
133
134
        let guided_decoding_backend = self.get_guided_decoding_backend();
        let guided_json = self.get_guided_json();
        let guided_regex = self.get_guided_regex();
        let guided_grammar = self.get_guided_grammar();
        let guided_choice = self.get_guided_choice();
135
        let guided_whitespace_pattern = self.get_guided_whitespace_pattern();
136
137

        let guided_decoding = match common::GuidedDecodingOptions::from_optional(
138
            guided_json,
139
140
141
142
            guided_regex,
            guided_choice,
            guided_grammar,
            guided_decoding_backend,
143
            guided_whitespace_pattern,
144
145
146
147
148
149
        ) {
            Ok(options) => options,
            Err(e) => {
                // Handle the validation error (log, return error, etc.)
                tracing::error!("Invalid guided decoding options: {:?}", e);
                return Err(e);
150
            }
151
        };
152

153
        Ok(common::SamplingOptions {
154
155
            n,
            best_of,
156
157
            frequency_penalty,
            presence_penalty,
158
            repetition_penalty,
159
160
            temperature,
            top_p,
161
            top_k,
162
163
            min_p,
            seed,
164
165
            use_beam_search: None,
            length_penalty: None,
166
            guided_decoding,
167
            include_stop_str_in_output,
168
169
170
171
172
173
        })
    }
}

impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
    fn extract_stop_conditions(&self) -> Result<common::StopConditions> {
Paul Hendricks's avatar
Paul Hendricks committed
174
        let max_tokens = self.get_max_tokens();
175
176
        let min_tokens = self.get_min_tokens();
        let stop = self.get_stop();
177
        let max_thinking_tokens = self.get_max_thinking_tokens();
178

179
180
181
182
        if let Some(stop) = &stop
            && stop.len() > 4
        {
            anyhow::bail!("stop conditions must be less than 4")
183
184
        }

185
186
        // Use the trait method to get ignore_eos, which handles precedence
        let ignore_eos = self.get_ignore_eos();
187
188
189

        Ok(common::StopConditions {
            max_tokens,
Paul Hendricks's avatar
Paul Hendricks committed
190
            min_tokens,
191
192
193
            stop,
            stop_token_ids_hidden: None,
            ignore_eos,
194
            max_thinking_tokens,
195
196
197
198
        })
    }
}

Greg Clark's avatar
Greg Clark committed
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T {
    /// Collects the output-related request fields into `common::OutputOptions`.
    ///
    /// Each field is copied straight from its accessor; this implementation always
    /// returns `Ok`.
    fn extract_output_options(&self) -> Result<common::OutputOptions> {
        let options = common::OutputOptions {
            logprobs: self.get_logprobs(),
            prompt_logprobs: self.get_prompt_logprobs(),
            skip_special_tokens: self.get_skip_special_tokens(),
            formatted_prompt: self.get_formatted_prompt(),
        };
        Ok(options)
    }
}

215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
/// Produces the UTF-8 bytes of `token` for use in OpenAI logprobs responses.
///
/// An empty token (unknown/unresolved tokens from the backend) yields `None`; any other
/// token yields `Some` with a copy of its bytes.
pub(crate) fn token_to_utf8_bytes(token: &str) -> Option<Vec<u8>> {
    match token {
        "" => None,
        non_empty => Some(non_empty.as_bytes().to_vec()),
    }
}

/// Maps backend `TopLogprob` entries to OpenAI-compatible `TopLogprobs` entries.
///
/// Each entry keeps its token text (empty string when the backend supplied none), its
/// logprob narrowed to `f32`, and its bytes (falling back to the token's UTF-8 bytes when
/// the backend supplied none). If no entry carries `selected_token_id`, an entry built
/// from `selected_token` and `selected_logprob` is appended at the end.
pub(crate) fn convert_backend_top_logprobs(
    top_lps: &[common::llm_backend::TopLogprob],
    selected_token: &str,
    selected_token_id: TokenIdType,
    selected_logprob: f32,
) -> Vec<dynamo_async_openai::types::TopLogprobs> {
    let mut entries: Vec<dynamo_async_openai::types::TopLogprobs> = Vec::new();
    let mut selected_seen = false;

    for candidate in top_lps {
        let token = candidate.token.clone().unwrap_or_default();
        if candidate.token_id == selected_token_id {
            selected_seen = true;
        }
        // Prefer bytes reported by the backend; otherwise derive them from the token text.
        let bytes = match candidate.bytes.clone() {
            Some(reported) => Some(reported),
            None => token_to_utf8_bytes(&token),
        };
        entries.push(dynamo_async_openai::types::TopLogprobs {
            token,
            logprob: candidate.logprob as f32,
            bytes,
        });
    }

    if selected_seen {
        return entries;
    }

    entries.push(dynamo_async_openai::types::TopLogprobs {
        token: selected_token.to_string(),
        logprob: selected_logprob,
        bytes: token_to_utf8_bytes(selected_token),
    });
    entries
}

258
259
pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
    Send + 'static
260
261
262
263
264
{
    fn choice_from_postprocessor(
        &mut self,
        response: common::llm_backend::BackendOutput,
    ) -> Result<ResponseType>;
265
266
267

    /// Gets the current prompt token count (Input Sequence Length).
    fn get_isl(&self) -> Option<u32>;
268
269
270
271
272
273

    /// Creates a final usage-only chunk for OpenAI compliance.
    fn create_usage_chunk(&self) -> ResponseType;

    /// Check if usage tracking is enabled.
    fn is_usage_enabled(&self) -> bool;
274

275
276
277
    /// Check if continuous usage tracking is enabled.
    fn is_continuous_usage_enabled(&self) -> bool;

278
279
    /// Get the current usage statistics with properly calculated total_tokens.
    fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage;
280
281
282
283
284

    /// Returns the request tracker if available, for accessing worker timing metrics.
    fn tracker(&self) -> Option<std::sync::Arc<common::timing::RequestTracker>> {
        None
    }
285
}
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301

/// Optional parser selections for a request's output.
///
/// Both fields default to `None` via the derived `Default`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct ParsingOptions {
    /// Name of the tool-call parser to use, if any.
    pub tool_call_parser: Option<String>,

    /// Name of the reasoning parser to use, if any.
    pub reasoning_parser: Option<String>,
}

impl ParsingOptions {
    pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
        Self {
            tool_call_parser,
            reasoning_parser,
        }
    }
}