// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use std::collections::HashMap;
17
18
19
20

use anyhow::Result;
use futures::StreamExt;

21
use super::NvCreateCompletionResponse;
22
23
use crate::protocols::{
    codec::{Message, SseCodecError},
Paul Hendricks's avatar
Paul Hendricks committed
24
    common::FinishReason,
25
26
27
28
29
30
31
    convert_sse_stream, Annotated, DataStream,
};

/// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
pub struct DeltaAggregator {
    id: String,
    model: String,
32
    created: u32,
33
    usage: Option<async_openai::types::CompletionUsage>,
34
    system_fingerprint: Option<String>,
35
    choices: HashMap<u32, DeltaChoice>,
36
37
38
39
    error: Option<String>,
}

struct DeltaChoice {
40
    index: u32,
41
    text: String,
Paul Hendricks's avatar
Paul Hendricks committed
42
    finish_reason: Option<FinishReason>,
43
    logprobs: Option<async_openai::types::Logprobs>,
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
}

impl Default for DeltaAggregator {
    fn default() -> Self {
        Self::new()
    }
}

impl DeltaAggregator {
    pub fn new() -> Self {
        Self {
            id: "".to_string(),
            model: "".to_string(),
            created: 0,
            usage: None,
            system_fingerprint: None,
            choices: HashMap::new(),
            error: None,
        }
    }

    /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
    pub async fn apply(
67
68
        stream: DataStream<Annotated<NvCreateCompletionResponse>>,
    ) -> Result<NvCreateCompletionResponse> {
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
        let aggregator = stream
            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
                let delta = match delta.ok() {
                    Ok(delta) => delta,
                    Err(error) => {
                        aggregator.error = Some(error);
                        return aggregator;
                    }
                };

                if aggregator.error.is_none() && delta.data.is_some() {
                    // note: we could extract annotations here and add them to the aggregator
                    // to be return as part of the NIM Response Extension
                    // TODO(#14) - Aggregate Annotation

                    // these are cheap to move so we do it every time since we are consuming the delta
                    let delta = delta.data.unwrap();
86
87
                    aggregator.id = delta.inner.id;
                    aggregator.model = delta.inner.model;
88
                    aggregator.created = delta.inner.created;
89
                    if let Some(usage) = delta.inner.usage {
90
91
                        aggregator.usage = Some(usage);
                    }
92
                    if let Some(system_fingerprint) = delta.inner.system_fingerprint {
93
94
95
96
                        aggregator.system_fingerprint = Some(system_fingerprint);
                    }

                    // handle the choices
97
                    for choice in delta.inner.choices {
98
99
100
                        let state_choice =
                            aggregator
                                .choices
101
                                .entry(choice.index)
102
                                .or_insert(DeltaChoice {
103
                                    index: choice.index,
104
105
106
107
108
109
110
                                    text: "".to_string(),
                                    finish_reason: None,
                                    logprobs: choice.logprobs,
                                });

                        state_choice.text.push_str(&choice.text);

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
                        // TODO - handle logprobs

                        // Handle CompletionFinishReason -> FinishReason conversation
                        state_choice.finish_reason = match choice.finish_reason {
                            Some(async_openai::types::CompletionFinishReason::Stop) => {
                                Some(FinishReason::Stop)
                            }
                            Some(async_openai::types::CompletionFinishReason::Length) => {
                                Some(FinishReason::Length)
                            }
                            Some(async_openai::types::CompletionFinishReason::ContentFilter) => {
                                Some(FinishReason::ContentFilter)
                            }
                            None => None,
                        };
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
                    }
                }
                aggregator
            })
            .await;

        // If we have an error, return it
        let aggregator = if let Some(error) = aggregator.error {
            return Err(anyhow::anyhow!(error));
        } else {
            aggregator
        };

        // extra the aggregated deltas and sort by index
        let mut choices: Vec<_> = aggregator
            .choices
            .into_values()
143
            .map(async_openai::types::Choice::from)
144
145
146
147
            .collect();

        choices.sort_by(|a, b| a.index.cmp(&b.index));

148
        let inner = async_openai::types::CreateCompletionResponse {
149
            id: aggregator.id,
150
            created: aggregator.created,
151
152
            usage: aggregator.usage,
            model: aggregator.model,
153
            object: "text_completion".to_string(),
154
155
            system_fingerprint: aggregator.system_fingerprint,
            choices,
156
157
158
159
160
        };

        let response = NvCreateCompletionResponse { inner };

        Ok(response)
161
162
163
    }
}

164
impl From<DeltaChoice> for async_openai::types::Choice {
165
    fn from(delta: DeltaChoice) -> Self {
166
        let finish_reason = delta.finish_reason.map(Into::into);
167

168
        async_openai::types::Choice {
169
            index: delta.index,
170
171
172
173
174
175
176
            text: delta.text,
            finish_reason,
            logprobs: delta.logprobs,
        }
    }
}

177
impl NvCreateCompletionResponse {
178
179
    pub async fn from_sse_stream(
        stream: DataStream<Result<Message, SseCodecError>>,
180
181
182
    ) -> Result<NvCreateCompletionResponse> {
        let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
        NvCreateCompletionResponse::from_annotated_stream(stream).await
183
184
185
    }

    pub async fn from_annotated_stream(
186
187
        stream: DataStream<Annotated<NvCreateCompletionResponse>>,
    ) -> Result<NvCreateCompletionResponse> {
188
189
190
191
192
193
        DeltaAggregator::apply(stream).await
    }
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use futures::stream;

    use super::*;
    use crate::protocols::openai::completions::NvCreateCompletionResponse;

    /// Builds an annotated delta holding exactly one choice with the given index, text
    /// and (optional) finish reason. All other response fields use fixed test values.
    fn create_test_delta(
        index: u32,
        text: &str,
        finish_reason: Option<String>,
    ) -> Annotated<NvCreateCompletionResponse> {
        // This will silently discard invalid finish-reason values and fall back
        // to None - totally fine since this is test code.
        let finish_reason = finish_reason
            .as_deref()
            .and_then(|s| FinishReason::from_str(s).ok())
            .map(Into::into);

        let inner = async_openai::types::CreateCompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices: vec![async_openai::types::Choice {
                index,
                text: text.to_string(),
                finish_reason,
                logprobs: None,
            }],
            object: "text_completion".to_string(),
        };

        let response = NvCreateCompletionResponse { inner };

        Annotated {
            data: Some(response),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        }
    }

    #[tokio::test]
    async fn test_empty_stream() {
        // Create an empty stream
        let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify that the response is empty and has default values
        assert_eq!(response.inner.id, "");
        assert_eq!(response.inner.model, "");
        assert_eq!(response.inner.created, 0);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 0);
    }

    #[tokio::test]
    async fn test_single_delta() {
        // Create a sample delta
        let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()));

        // Create a stream
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields
        assert_eq!(response.inner.id, "test_id");
        assert_eq!(response.inner.model, "meta/llama-3.1-8b");
        assert_eq!(response.inner.created, 1234567890);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 1);
        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello,".to_string());
        assert_eq!(
            choice.finish_reason,
            Some(async_openai::types::CompletionFinishReason::Length)
        );
        assert!(choice.logprobs.is_none());
    }

    #[tokio::test]
    async fn test_multiple_deltas_same_choice() {
        // Create multiple deltas with the same choice index:
        // the first carries text only, the second carries text and a finish reason.
        let annotated_delta1 = create_test_delta(0, "Hello,", None);
        let annotated_delta2 = create_test_delta(0, " world!", Some("stop".to_string()));

        // Create a stream
        let annotated_deltas = vec![annotated_delta1, annotated_delta2];
        let stream = Box::pin(stream::iter(annotated_deltas));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields
        assert_eq!(response.inner.choices.len(), 1);
        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello, world!".to_string());
        assert_eq!(
            choice.finish_reason,
            Some(async_openai::types::CompletionFinishReason::Stop)
        );
    }

    #[tokio::test]
    async fn test_multiple_choices() {
        // Create a delta with multiple choices
        let inner = async_openai::types::CreateCompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices: vec![
                async_openai::types::Choice {
                    index: 0,
                    text: "Choice 0".to_string(),
                    finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
                    logprobs: None,
                },
                async_openai::types::Choice {
                    index: 1,
                    text: "Choice 1".to_string(),
                    finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
                    logprobs: None,
                },
            ],
            object: "text_completion".to_string(),
        };

        let response = NvCreateCompletionResponse { inner };

        let annotated_delta = Annotated {
            data: Some(response),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        };

        // Create a stream
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        // Call DeltaAggregator::apply
        let result = DeltaAggregator::apply(stream).await;

        // Check the result
        assert!(result.is_ok());
        let response = result.unwrap();

        // Verify the response fields. The choices are deliberately not re-sorted here:
        // `apply` sorts them by index, and this test checks that ordering.
        assert_eq!(response.inner.choices.len(), 2);
        let choice0 = &response.inner.choices[0];
        assert_eq!(choice0.index, 0);
        assert_eq!(choice0.text, "Choice 0".to_string());
        assert_eq!(
            choice0.finish_reason,
            Some(async_openai::types::CompletionFinishReason::Stop)
        );

        let choice1 = &response.inner.choices[1];
        assert_eq!(choice1.index, 1);
        assert_eq!(choice1.text, "Choice 1".to_string());
        assert_eq!(
            choice1.finish_reason,
            Some(async_openai::types::CompletionFinishReason::Stop)
        );
    }
}