// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use serde::{Deserialize, Serialize};

6
pub use super::FinishReason;
7
pub use super::preprocessor::PreprocessedRequest;
8
use crate::protocols::TokenIdType;
9
10
use dynamo_protocols::types::CompletionUsage;
use dynamo_protocols::types::StopReason;
11
use dynamo_runtime::error::DynamoError;
12
use dynamo_runtime::protocols::maybe_error::MaybeError;
13
14
15
16

/// A single (optionally) decoded token.
///
/// `None` presumably means no textual form is available for that token
/// (e.g. it has not been detokenized) — TODO confirm against producers.
pub type TokenType = Option<String>;

/// A list of log probabilities; presumably one entry per generated token — TODO confirm.
pub type LogProbs = Vec<f64>;
/// Discriminates which modality an engine output carries.
///
/// Serialized in lowercase (`"text"`, `"image"`, `"video"`, `"audio"`).
/// [`OutputType::Text`] is the default variant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputType {
    /// Text output (the default).
    #[default]
    Text,
    /// Image output.
    Image,
    /// Video output.
    Video,
    /// Audio output.
    Audio,
}

/// Carries the URL of an image included in a response.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageUrlData {
    /// Location of the image; stored as an unvalidated string.
    pub url: String,
}

/// Carries the URL of a video included in a response.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VideoUrlData {
    /// Location of the video; stored as an unvalidated string.
    pub url: String,
}

/// Carries the URL of an audio clip included in a response.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AudioUrlData {
    /// Location of the audio; stored as an unvalidated string.
    pub url: String,
}

/// One piece of multimodal output content (internal representation).
///
/// Serialized as an internally tagged object. The `type` field holds the
/// snake_case variant name (`text`, `image_url`, `video_url`, `audio_url`),
/// and the payload field sits next to it, e.g.
/// `{"type": "image_url", "image_url": {"url": "..."}}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ContentPart {
    /// A text segment.
    Text { text: String },
    /// An image referenced by URL.
    ImageUrl { image_url: ImageUrlData },
    /// A video referenced by URL.
    VideoUrl { video_url: VideoUrlData },
    /// An audio clip referenced by URL.
    AudioUrl { audio_url: AudioUrlData },
}

Greg Clark's avatar
Greg Clark committed
56
57
58
59
60
61
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TopLogprob {
    pub rank: u32,
    pub token_id: TokenIdType,
    pub token: TokenType,
    pub logprob: f64,
62
63
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bytes: Option<Vec<u8>>,
Greg Clark's avatar
Greg Clark committed
64
65
66
}
pub type TopLogprobs = Vec<Vec<TopLogprob>>; // num_tokens x top_logprobs

67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct BackendOutput {
    /// New token_ids generated from the LLM Engine
    pub token_ids: Vec<TokenIdType>,

    /// Unlike [`LLMEngineOutput::tokens`], this is a vector of tokens, not an optional.
    /// The size of this vector should be the same as the size of `token_ids`.
    pub tokens: Vec<TokenType>,

    /// Decoded text from the list tokens.
    pub text: Option<String>,

    /// Optional cumulative log probabilities
    pub cum_log_probs: Option<f64>,

    /// Optional log probabilities
    pub log_probs: Option<LogProbs>,

Greg Clark's avatar
Greg Clark committed
85
86
    pub top_logprobs: Option<TopLogprobs>,

87
88
89
    // TODO: Enrich this with more information as can apply our first-level postprocessing
    // logic and return more detailed information
    pub finish_reason: Option<FinishReason>,
90
91
92
93
94
95

    /// The stop string or token that triggered the stop condition.
    /// This is set when finish_reason is Stop and identifies what triggered it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<StopReason>,

96
97
    // Model Deployment Card checksum
    //pub mdcsum: String,
98
99
100

    // Index field for batch requests to match OpenAI format
    pub index: Option<u32>,
101
102
103
104

    // Token usage information
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub completion_usage: Option<CompletionUsage>,
105
106
107
108

    /// Disaggregated execution parameters (for prefill/decode separation)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub disaggregated_params: Option<serde_json::Value>,
109
110
111
112
113
114
115
116
117
118
}

/// The LLM engine and backnd with manage it's own state, specifically translating how a
/// given request/slot is managed on that particular backend.
///
/// For nvLLM's purpose, it has a single tracable request_id as part of it's context that
/// has propaged through the service pipeline to the backend.
///
/// This is the minimal raw output from the LLM engine. The Backend may then apply multiple
/// levels of post-processing before the BackendOutput is returns
119
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
120
121
122
123
124
125
126
127
128
129
130
pub struct LLMEngineOutput {
    // new token_ids
    pub token_ids: Vec<TokenIdType>,

    /// If the LLM Engine performs the detokenization, then this will have a Some of the detokenized
    /// text/tokens. If this value is None, then the Backend is responsible for detokenization.
    pub tokens: Option<Vec<TokenType>>,

    // decoded text -
    pub text: Option<String>,

131
132
133
134
135
136
137
138
    /// Output type discriminator (text, image, video, audio)
    #[serde(default)]
    pub output_type: OutputType,

    /// Multimodal content parts (for non-text outputs)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content_parts: Option<Vec<ContentPart>>,

139
140
141
142
143
144
    /// cumulative log probabilities
    pub cum_log_probs: Option<f64>,

    /// Optional log probabilities
    pub log_probs: Option<LogProbs>,

Greg Clark's avatar
Greg Clark committed
145
146
    pub top_logprobs: Option<TopLogprobs>,

147
148
149
    // TODO: Enrich this with more information as can apply our first-level postprocessing
    // logic and return more detailed information
    pub finish_reason: Option<FinishReason>,
150

151
152
153
154
155
    /// The stop string or token that triggered the stop condition.
    /// This is set when finish_reason is Stop and identifies what triggered it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<StopReason>,

156
157
    // Index field for batch requests to match OpenAI format
    pub index: Option<u32>,
158

159
160
161
162
    /// Disaggregated execution parameters (for prefill/decode separation)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub disaggregated_params: Option<serde_json::Value>,

163
164
165
    /// Additional arguments for extensibility
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub extra_args: Option<serde_json::Value>,
166
167
168
169

    // Token usage information
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub completion_usage: Option<CompletionUsage>,
170
171
172
173
174
175
176
177
}

impl LLMEngineOutput {
    pub fn cancelled() -> Self {
        LLMEngineOutput {
            token_ids: vec![],
            tokens: None,
            text: None,
178
179
            output_type: OutputType::default(),
            content_parts: None,
180
181
            cum_log_probs: None,
            log_probs: None,
Greg Clark's avatar
Greg Clark committed
182
            top_logprobs: None,
183
            finish_reason: Some(FinishReason::Cancelled),
184
            stop_reason: None,
185
            index: None,
186
            disaggregated_params: None,
187
            extra_args: None,
188
            completion_usage: None,
189
190
191
192
193
194
195
196
        }
    }

    pub fn stop() -> Self {
        LLMEngineOutput {
            token_ids: vec![],
            tokens: None,
            text: None,
197
198
            output_type: OutputType::default(),
            content_parts: None,
199
200
201
            cum_log_probs: None,
            log_probs: None,
            finish_reason: Some(FinishReason::Stop),
202
            stop_reason: None,
Greg Clark's avatar
Greg Clark committed
203
            top_logprobs: None,
204
            index: None,
205
            disaggregated_params: None,
206
            extra_args: None,
207
            completion_usage: None,
208
209
210
211
212
213
214
215
        }
    }

    pub fn length() -> Self {
        LLMEngineOutput {
            token_ids: vec![],
            tokens: None,
            text: None,
216
217
            output_type: OutputType::default(),
            content_parts: None,
218
219
            cum_log_probs: None,
            log_probs: None,
Greg Clark's avatar
Greg Clark committed
220
            top_logprobs: None,
221
            finish_reason: Some(FinishReason::Length),
222
            stop_reason: None,
223
            index: None,
224
            disaggregated_params: None,
225
            extra_args: None,
226
            completion_usage: None,
227
228
229
230
231
232
233
234
        }
    }

    pub fn error(err_msg: String) -> Self {
        LLMEngineOutput {
            token_ids: vec![],
            tokens: None,
            text: None,
235
236
            output_type: OutputType::default(),
            content_parts: None,
237
238
            cum_log_probs: None,
            log_probs: None,
Greg Clark's avatar
Greg Clark committed
239
            top_logprobs: None,
240
            finish_reason: Some(FinishReason::Error(err_msg)),
241
            stop_reason: None,
242
            index: None,
243
            disaggregated_params: None,
244
            extra_args: None,
245
            completion_usage: None,
246
247
248
        }
    }
}
249

250
impl MaybeError for LLMEngineOutput {
251
252
    fn from_err(err: impl std::error::Error + 'static) -> Self {
        LLMEngineOutput::error(err.to_string())
253
254
    }

255
    fn err(&self) -> Option<DynamoError> {
256
        if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
257
            Some(DynamoError::msg(err_msg.clone()))
258
259
260
261
262
263
        } else {
            None
        }
    }
}

264
265
266
267
268
269
270
271
272
273
/// Raw output from embedding engines containing embedding vectors
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct EmbeddingsEngineOutput {
    /// Generated embedding vectors (one per input text)
    pub embeddings: Vec<Vec<f64>>,

    /// Token usage information
    pub prompt_tokens: u32,
    pub total_tokens: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_maybe_error() {
        let output = LLMEngineOutput::stop();
        assert!(output.err().is_none());
        assert!(output.is_ok());
        assert!(!output.is_err());

        let output = LLMEngineOutput::error("Test error".to_string());
        assert!(format!("{}", output.err().unwrap()).contains("Test error"));
        assert!(!output.is_ok());
        assert!(output.is_err());
    }

    #[test]
    fn test_from_err_wraps_message() {
        let output = LLMEngineOutput::from_err(std::io::Error::other("boom"));
        assert!(output.is_err());
        assert!(format!("{}", output.err().unwrap()).contains("boom"));
    }

    #[test]
    fn test_terminal_constructors_set_only_finish_reason() {
        let cases = [
            (LLMEngineOutput::cancelled(), FinishReason::Cancelled),
            (LLMEngineOutput::stop(), FinishReason::Stop),
            (LLMEngineOutput::length(), FinishReason::Length),
            (
                LLMEngineOutput::error("e".to_string()),
                FinishReason::Error("e".to_string()),
            ),
        ];
        for (output, expected) in cases {
            assert_eq!(output.finish_reason, Some(expected));
            // Everything except the finish reason must be the default (empty) value.
            let rest = LLMEngineOutput {
                finish_reason: None,
                ..output
            };
            assert_eq!(rest, LLMEngineOutput::default());
        }
    }
}