aggregator.rs 14.5 KB
Newer Older
1
2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
//
4
5
6
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
7
//
8
// http://www.apache.org/licenses/LICENSE-2.0
9
//
10
11
12
13
14
15
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use std::collections::HashMap;
17
18

use anyhow::Result;
Ryan Olson's avatar
Ryan Olson committed
19
use futures::{Stream, StreamExt};
20

21
use super::NvCreateCompletionResponse;
22
23
use crate::protocols::{
    codec::{Message, SseCodecError},
Paul Hendricks's avatar
Paul Hendricks committed
24
    common::FinishReason,
25
26
27
    convert_sse_stream,
    openai::ParsingOptions,
    Annotated, DataStream,
28
29
30
31
32
33
};

/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
pub struct DeltaAggregator {
    id: String,
    model: String,
34
    created: u32,
35
    usage: Option<dynamo_async_openai::types::CompletionUsage>,
36
    system_fingerprint: Option<String>,
37
    choices: HashMap<u32, DeltaChoice>,
38
39
40
41
    error: Option<String>,
}

struct DeltaChoice {
42
    index: u32,
43
    text: String,
Paul Hendricks's avatar
Paul Hendricks committed
44
    finish_reason: Option<FinishReason>,
45
    logprobs: Option<dynamo_async_openai::types::Logprobs>,
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
}

impl Default for DeltaAggregator {
    fn default() -> Self {
        Self::new()
    }
}

impl DeltaAggregator {
    pub fn new() -> Self {
        Self {
            id: "".to_string(),
            model: "".to_string(),
            created: 0,
            usage: None,
            system_fingerprint: None,
            choices: HashMap::new(),
            error: None,
        }
    }

    /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
    pub async fn apply(
Ryan Olson's avatar
Ryan Olson committed
69
        stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
70
        parsing_options: ParsingOptions,
71
    ) -> Result<NvCreateCompletionResponse> {
72
        tracing::debug!("Tool Call Parser: {:?}", parsing_options.tool_call_parser); // TODO: remove this once completion has tool call support
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
        let aggregator = stream
            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
                let delta = match delta.ok() {
                    Ok(delta) => delta,
                    Err(error) => {
                        aggregator.error = Some(error);
                        return aggregator;
                    }
                };

                if aggregator.error.is_none() && delta.data.is_some() {
                    // note: we could extract annotations here and add them to the aggregator
                    // to be return as part of the NIM Response Extension
                    // TODO(#14) - Aggregate Annotation

                    // these are cheap to move so we do it every time since we are consuming the delta
                    let delta = delta.data.unwrap();
90
91
                    aggregator.id = delta.inner.id;
                    aggregator.model = delta.inner.model;
92
                    aggregator.created = delta.inner.created;
93
                    if let Some(usage) = delta.inner.usage {
94
95
                        aggregator.usage = Some(usage);
                    }
96
                    if let Some(system_fingerprint) = delta.inner.system_fingerprint {
97
98
99
100
                        aggregator.system_fingerprint = Some(system_fingerprint);
                    }

                    // handle the choices
101
                    for choice in delta.inner.choices {
102
103
104
                        let state_choice =
                            aggregator
                                .choices
105
                                .entry(choice.index)
106
                                .or_insert(DeltaChoice {
107
                                    index: choice.index,
108
109
110
111
112
113
114
                                    text: "".to_string(),
                                    finish_reason: None,
                                    logprobs: choice.logprobs,
                                });

                        state_choice.text.push_str(&choice.text);

115
116
117
118
                        // TODO - handle logprobs

                        // Handle CompletionFinishReason -> FinishReason conversation
                        state_choice.finish_reason = match choice.finish_reason {
119
                            Some(dynamo_async_openai::types::CompletionFinishReason::Stop) => {
120
121
                                Some(FinishReason::Stop)
                            }
122
                            Some(dynamo_async_openai::types::CompletionFinishReason::Length) => {
123
124
                                Some(FinishReason::Length)
                            }
125
126
127
                            Some(
                                dynamo_async_openai::types::CompletionFinishReason::ContentFilter,
                            ) => Some(FinishReason::ContentFilter),
128
129
                            None => None,
                        };
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
                    }
                }
                aggregator
            })
            .await;

        // If we have an error, return it
        let aggregator = if let Some(error) = aggregator.error {
            return Err(anyhow::anyhow!(error));
        } else {
            aggregator
        };

        // extra the aggregated deltas and sort by index
        let mut choices: Vec<_> = aggregator
            .choices
            .into_values()
147
            .map(dynamo_async_openai::types::Choice::from)
148
149
150
151
            .collect();

        choices.sort_by(|a, b| a.index.cmp(&b.index));

152
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
153
            id: aggregator.id,
154
            created: aggregator.created,
155
156
            usage: aggregator.usage,
            model: aggregator.model,
157
            object: "text_completion".to_string(),
158
159
            system_fingerprint: aggregator.system_fingerprint,
            choices,
160
161
162
163
164
        };

        let response = NvCreateCompletionResponse { inner };

        Ok(response)
165
166
167
    }
}

168
impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
169
    fn from(delta: DeltaChoice) -> Self {
170
        let finish_reason = delta.finish_reason.map(Into::into);
171

172
        dynamo_async_openai::types::Choice {
173
            index: delta.index,
174
175
176
177
178
179
180
            text: delta.text,
            finish_reason,
            logprobs: delta.logprobs,
        }
    }
}

181
impl NvCreateCompletionResponse {
182
183
    pub async fn from_sse_stream(
        stream: DataStream<Result<Message, SseCodecError>>,
184
        parsing_options: ParsingOptions,
185
186
    ) -> Result<NvCreateCompletionResponse> {
        let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
187
        NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options).await
188
189
190
    }

    pub async fn from_annotated_stream(
Ryan Olson's avatar
Ryan Olson committed
191
        stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
192
        parsing_options: ParsingOptions,
193
    ) -> Result<NvCreateCompletionResponse> {
194
        DeltaAggregator::apply(stream, parsing_options).await
195
196
197
198
199
    }
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use futures::stream;

    use super::*;
    use crate::protocols::openai::completions::NvCreateCompletionResponse;

    /// Builds a single-choice delta wrapped in an `Annotated` envelope.
    ///
    /// `finish_reason` is parsed with `FinishReason::from_str`; invalid values silently fall
    /// back to `None`, which is fine for test fixtures.
    fn create_test_delta(
        index: u32,
        text: &str,
        finish_reason: Option<String>,
    ) -> Annotated<NvCreateCompletionResponse> {
        let finish_reason = finish_reason
            .as_deref()
            .and_then(|s| FinishReason::from_str(s).ok())
            .map(Into::into);

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices: vec![dynamo_async_openai::types::Choice {
                index,
                text: text.to_string(),
                finish_reason,
                logprobs: None,
            }],
            object: "text_completion".to_string(),
        };

        Annotated {
            data: Some(NvCreateCompletionResponse { inner }),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        }
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());

        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregating an empty stream should succeed");

        // An empty stream yields a response carrying only default values.
        assert_eq!(response.inner.id, "");
        assert_eq!(response.inner.model, "");
        assert_eq!(response.inner.created, 0);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 0);
    }

    #[tokio::test]
    async fn test_single_delta() {
        let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()));
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregation should succeed");

        assert_eq!(response.inner.id, "test_id");
        assert_eq!(response.inner.model, "meta/llama-3.1-8b");
        assert_eq!(response.inner.created, 1234567890);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 1);

        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello,");
        assert_eq!(
            choice.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Length)
        );
        assert!(choice.logprobs.is_none());
    }

    #[tokio::test]
    async fn test_multiple_deltas_same_choice() {
        // Two deltas for the same choice index: the first carries text and no finish reason,
        // the second carries more text and a `stop` finish reason.
        let annotated_delta1 = create_test_delta(0, "Hello,", None);
        let annotated_delta2 = create_test_delta(0, " world!", Some("stop".to_string()));
        let stream = Box::pin(stream::iter(vec![annotated_delta1, annotated_delta2]));

        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregation should succeed");

        assert_eq!(response.inner.choices.len(), 1);
        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello, world!");
        assert_eq!(
            choice.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
        );
    }

    #[tokio::test]
    async fn test_multiple_choices() {
        // A single delta that carries two choices.
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices: vec![
                dynamo_async_openai::types::Choice {
                    index: 0,
                    text: "Choice 0".to_string(),
                    finish_reason: Some(dynamo_async_openai::types::CompletionFinishReason::Stop),
                    logprobs: None,
                },
                dynamo_async_openai::types::Choice {
                    index: 1,
                    text: "Choice 1".to_string(),
                    finish_reason: Some(dynamo_async_openai::types::CompletionFinishReason::Stop),
                    logprobs: None,
                },
            ],
            object: "text_completion".to_string(),
        };

        let annotated_delta = Annotated {
            data: Some(NvCreateCompletionResponse { inner }),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        };
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregation should succeed");

        // `apply` returns choices sorted by index, so no re-sorting is needed here.
        assert_eq!(response.inner.choices.len(), 2);

        let choice0 = &response.inner.choices[0];
        assert_eq!(choice0.index, 0);
        assert_eq!(choice0.text, "Choice 0");
        assert_eq!(
            choice0.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
        );

        let choice1 = &response.inner.choices[1];
        assert_eq!(choice1.index, 1);
        assert_eq!(choice1.text, "Choice 1");
        assert_eq!(
            choice1.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
        );
    }

    #[tokio::test]
    async fn test_interleaved_choices_are_sorted_by_index() {
        // Deltas for two choices arrive interleaved and with the higher index first.
        let deltas = vec![
            create_test_delta(1, "B", None),
            create_test_delta(0, "A", None),
            create_test_delta(1, "b", Some("stop".to_string())),
            create_test_delta(0, "a", Some("length".to_string())),
        ];
        let stream = Box::pin(stream::iter(deltas));

        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregation should succeed");

        let choices = &response.inner.choices;
        assert_eq!(choices.len(), 2);

        assert_eq!(choices[0].index, 0);
        assert_eq!(choices[0].text, "Aa");
        assert_eq!(
            choices[0].finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Length)
        );

        assert_eq!(choices[1].index, 1);
        assert_eq!(choices[1].text, "Bb");
        assert_eq!(
            choices[1].finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
        );
    }
}