// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
5
6

use anyhow::Result;
Ryan Olson's avatar
Ryan Olson committed
7
use futures::{Stream, StreamExt};
8

9
use super::NvCreateCompletionResponse;
10
use crate::protocols::{
11
    Annotated, DataStream,
12
    codec::{Message, SseCodecError},
Paul Hendricks's avatar
Paul Hendricks committed
13
    common::FinishReason,
14
15
    convert_sse_stream,
    openai::ParsingOptions,
16
17
18
19
20
21
};

/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
pub struct DeltaAggregator {
    id: String,
    model: String,
22
    created: u32,
23
    usage: Option<dynamo_async_openai::types::CompletionUsage>,
24
    system_fingerprint: Option<String>,
25
    choices: HashMap<u32, DeltaChoice>,
26
    error: Option<String>,
27
    nvext: Option<serde_json::Value>,
28
29
30
}

struct DeltaChoice {
31
    index: u32,
32
    text: String,
Paul Hendricks's avatar
Paul Hendricks committed
33
    finish_reason: Option<FinishReason>,
34
    logprobs: Option<dynamo_async_openai::types::Logprobs>,
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
}

impl Default for DeltaAggregator {
    fn default() -> Self {
        Self::new()
    }
}

impl DeltaAggregator {
    pub fn new() -> Self {
        Self {
            id: "".to_string(),
            model: "".to_string(),
            created: 0,
            usage: None,
            system_fingerprint: None,
            choices: HashMap::new(),
            error: None,
53
            nvext: None,
54
55
56
57
58
        }
    }

    /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
    pub async fn apply(
Ryan Olson's avatar
Ryan Olson committed
59
        stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
60
        parsing_options: ParsingOptions,
61
    ) -> Result<NvCreateCompletionResponse> {
62
        tracing::debug!("Tool Call Parser: {:?}", parsing_options.tool_call_parser); // TODO: remove this once completion has tool call support
63
64
65
66
67
68
69
70
71
72
        let aggregator = stream
            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
                let delta = match delta.ok() {
                    Ok(delta) => delta,
                    Err(error) => {
                        aggregator.error = Some(error);
                        return aggregator;
                    }
                };

73
74
75
                if aggregator.error.is_none()
                    && let Some(delta) = delta.data
                {
76
77
78
                    // TODO(#14) - Aggregate Annotation

                    // these are cheap to move so we do it every time since we are consuming the delta
79
80
                    aggregator.id = delta.inner.id;
                    aggregator.model = delta.inner.model;
81
                    aggregator.created = delta.inner.created;
82
                    if let Some(usage) = delta.inner.usage {
83
84
                        aggregator.usage = Some(usage);
                    }
85
                    if let Some(system_fingerprint) = delta.inner.system_fingerprint {
86
87
                        aggregator.system_fingerprint = Some(system_fingerprint);
                    }
88
                    // Aggregate nvext field (take the last non-None value)
89
90
                    if delta.nvext.is_some() {
                        aggregator.nvext = delta.nvext;
91
                    }
92
93

                    // handle the choices
94
                    for choice in delta.inner.choices {
95
96
97
                        let state_choice =
                            aggregator
                                .choices
98
                                .entry(choice.index)
99
                                .or_insert(DeltaChoice {
100
                                    index: choice.index,
101
102
                                    text: "".to_string(),
                                    finish_reason: None,
Greg Clark's avatar
Greg Clark committed
103
                                    logprobs: None,
104
105
106
107
                                });

                        state_choice.text.push_str(&choice.text);

108
109
110
111
                        // TODO - handle logprobs

                        // Handle CompletionFinishReason -> FinishReason conversation
                        state_choice.finish_reason = match choice.finish_reason {
112
                            Some(dynamo_async_openai::types::CompletionFinishReason::Stop) => {
113
114
                                Some(FinishReason::Stop)
                            }
115
                            Some(dynamo_async_openai::types::CompletionFinishReason::Length) => {
116
117
                                Some(FinishReason::Length)
                            }
118
119
120
                            Some(
                                dynamo_async_openai::types::CompletionFinishReason::ContentFilter,
                            ) => Some(FinishReason::ContentFilter),
121
122
                            None => None,
                        };
Greg Clark's avatar
Greg Clark committed
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140

                        // Update logprobs
                        if let Some(logprobs) = &choice.logprobs {
                            let state_lps = state_choice.logprobs.get_or_insert(
                                dynamo_async_openai::types::Logprobs {
                                    tokens: Vec::new(),
                                    token_logprobs: Vec::new(),
                                    top_logprobs: Vec::new(),
                                    text_offset: Vec::new(),
                                },
                            );
                            state_lps.tokens.extend(logprobs.tokens.clone());
                            state_lps
                                .token_logprobs
                                .extend(logprobs.token_logprobs.clone());
                            state_lps.top_logprobs.extend(logprobs.top_logprobs.clone());
                            state_lps.text_offset.extend(logprobs.text_offset.clone());
                        }
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
                    }
                }
                aggregator
            })
            .await;

        // If we have an error, return it
        let aggregator = if let Some(error) = aggregator.error {
            return Err(anyhow::anyhow!(error));
        } else {
            aggregator
        };

        // extra the aggregated deltas and sort by index
        let mut choices: Vec<_> = aggregator
            .choices
            .into_values()
158
            .map(dynamo_async_openai::types::Choice::from)
159
160
161
162
            .collect();

        choices.sort_by(|a, b| a.index.cmp(&b.index));

163
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
164
            id: aggregator.id,
165
            created: aggregator.created,
166
167
            usage: aggregator.usage,
            model: aggregator.model,
168
            object: "text_completion".to_string(),
169
170
            system_fingerprint: aggregator.system_fingerprint,
            choices,
171
172
        };

173
174
175
176
        let response = NvCreateCompletionResponse {
            inner,
            nvext: aggregator.nvext,
        };
177
178

        Ok(response)
179
180
181
    }
}

182
impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
183
    fn from(delta: DeltaChoice) -> Self {
184
        let finish_reason = delta.finish_reason.map(Into::into);
185

186
        dynamo_async_openai::types::Choice {
187
            index: delta.index,
188
189
190
191
192
193
194
            text: delta.text,
            finish_reason,
            logprobs: delta.logprobs,
        }
    }
}

195
impl NvCreateCompletionResponse {
196
197
    pub async fn from_sse_stream(
        stream: DataStream<Result<Message, SseCodecError>>,
198
        parsing_options: ParsingOptions,
199
200
    ) -> Result<NvCreateCompletionResponse> {
        let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
201
        NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options).await
202
203
204
    }

    pub async fn from_annotated_stream(
Ryan Olson's avatar
Ryan Olson committed
205
        stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
206
        parsing_options: ParsingOptions,
207
    ) -> Result<NvCreateCompletionResponse> {
208
        DeltaAggregator::apply(stream, parsing_options).await
209
210
211
212
213
    }
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use futures::stream;

    use super::*;
    use crate::protocols::openai::completions::NvCreateCompletionResponse;

    /// Builds a single-choice annotated delta for the test model.
    ///
    /// `finish_reason` is parsed with [`FinishReason::from_str`]. Unparseable values silently
    /// become `None`, which is acceptable for test fixtures. When `logprob` is set, the delta
    /// carries one token (the whole `text`) with that logprob at text offset 0.
    fn create_test_delta(
        index: u32,
        text: &str,
        finish_reason: Option<String>,
        logprob: Option<f32>,
    ) -> Annotated<NvCreateCompletionResponse> {
        let finish_reason = finish_reason
            .as_deref()
            .and_then(|s| FinishReason::from_str(s).ok())
            .map(Into::into);

        let logprobs = logprob.map(|lp| dynamo_async_openai::types::Logprobs {
            tokens: vec![text.to_string()],
            token_logprobs: vec![Some(lp)],
            top_logprobs: vec![
                serde_json::to_value(dynamo_async_openai::types::TopLogprobs {
                    token: text.to_string(),
                    logprob: lp,
                    bytes: None,
                })
                .unwrap(),
            ],
            text_offset: vec![0],
        });

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices: vec![dynamo_async_openai::types::Choice {
                index,
                text: text.to_string(),
                finish_reason,
                logprobs,
            }],
            object: "text_completion".to_string(),
        };

        let response = NvCreateCompletionResponse { inner, nvext: None };

        Annotated {
            data: Some(response),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
            error: None,
        }
    }

    #[tokio::test]
    async fn test_empty_stream() {
        // Create an empty stream
        let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());

        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify that the response is empty and has default values
        assert_eq!(response.inner.id, "");
        assert_eq!(response.inner.model, "");
        assert_eq!(response.inner.created, 0);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 0);
    }

    #[tokio::test]
    async fn test_single_delta() {
        let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()), None);
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields
        assert_eq!(response.inner.id, "test_id");
        assert_eq!(response.inner.model, "meta/llama-3.1-8b");
        assert_eq!(response.inner.created, 1234567890);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 1);

        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello,");
        assert_eq!(
            choice.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Length)
        );
        assert!(choice.logprobs.is_none());
    }

    #[tokio::test]
    async fn test_multiple_deltas_same_choice() {
        // Two deltas for the same choice index: the first carries no finish reason and the
        // second carries `stop`. Text and logprobs must be concatenated in arrival order.
        let annotated_delta1 = create_test_delta(0, "Hello,", None, Some(-0.1));
        let annotated_delta2 =
            create_test_delta(0, " world!", Some("stop".to_string()), Some(-0.2));
        let stream = Box::pin(stream::iter(vec![annotated_delta1, annotated_delta2]));

        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

        assert!(result.is_ok());
        let response = result.unwrap();

        assert_eq!(response.inner.choices.len(), 1);
        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello, world!");
        assert_eq!(
            choice.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
        );

        let logprobs = choice
            .logprobs
            .as_ref()
            .expect("logprobs from both deltas should be aggregated");
        assert_eq!(logprobs.tokens, vec!["Hello,", " world!"]);
        assert_eq!(logprobs.token_logprobs, vec![Some(-0.1), Some(-0.2)]);
        assert_eq!(logprobs.top_logprobs.len(), 2);
        assert_eq!(logprobs.text_offset.len(), 2);
    }

    #[tokio::test]
    async fn test_multiple_choices() {
        // A single delta carrying two choices.
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices: vec![
                dynamo_async_openai::types::Choice {
                    index: 0,
                    text: "Choice 0".to_string(),
                    finish_reason: Some(dynamo_async_openai::types::CompletionFinishReason::Stop),
                    logprobs: None,
                },
                dynamo_async_openai::types::Choice {
                    index: 1,
                    text: "Choice 1".to_string(),
                    finish_reason: Some(dynamo_async_openai::types::CompletionFinishReason::Stop),
                    logprobs: None,
                },
            ],
            object: "text_completion".to_string(),
        };

        let response = NvCreateCompletionResponse { inner, nvext: None };

        let annotated_delta = Annotated {
            data: Some(response),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
            error: None,
        };

        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

        assert!(result.is_ok());
        let response = result.unwrap();

        // The aggregator itself must return choices ordered by index.
        assert_eq!(response.inner.choices.len(), 2);

        let choice0 = &response.inner.choices[0];
        assert_eq!(choice0.index, 0);
        assert_eq!(choice0.text, "Choice 0");
        assert_eq!(
            choice0.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
        );

        let choice1 = &response.inner.choices[1];
        assert_eq!(choice1.index, 1);
        assert_eq!(choice1.text, "Choice 1");
        assert_eq!(
            choice1.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
        );
    }
}