// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use dynamo_async_openai::types::{
    ChatChoiceStream, ChatCompletionMessageContent, ChatCompletionStreamResponseDelta,
    CreateChatCompletionStreamResponse, Role,
};
Neelay Shah's avatar
Neelay Shah committed
8
use dynamo_llm::protocols::{
9
    Annotated, ContentProvider, DataStream,
10
    codec::{Message, SseCodecError, create_message_stream},
11
    openai::{
12
        ParsingOptions,
13
14
15
16
        chat_completions::{
            NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse,
            aggregator::ChatCompletionAggregator,
        },
17
        completions::NvCreateCompletionResponse,
18
    },
19
};
20
use futures::StreamExt;
21

22
23
24
25
26
27
28
/// Borrows the plain text held by a chat message content value.
///
/// Only the `Text` variant carries a single string; for `Parts` content this returns the
/// empty string.
fn get_text(content: &ChatCompletionMessageContent) -> &str {
    let ChatCompletionMessageContent::Text(body) = content else {
        // Multi-part content has no single text payload to hand back.
        return "";
    };
    body.as_str()
}

29
30
31
32
33
34
35
36
37
38
39
40
41
42
/// Directory holding the replay files used by the `/v1/completions` tests; passed to
/// `create_stream` as the root path.
// NOTE(review): the path is relative, so it presumably resolves from the crate root when run
// under `cargo test` — confirm.
const CMPL_ROOT_PATH: &str = "tests/data/replays/meta/llama-3.1-8b-instruct/completions";
/// Directory holding the replay files used by the chat-completions tests, including the
/// `edge_cases/` files referenced by the edge-case tests.
const CHAT_ROOT_PATH: &str = "tests/data/replays/meta/llama-3.1-8b-instruct/chat_completions";

/// Reads a replay file from disk and turns its contents into a message stream.
///
/// `file_name` is joined onto `root_path`, the file is read fully into memory, and the text is
/// handed to `create_message_stream`.
///
/// # Panics
///
/// Panics if the file cannot be read. The panic message names the path that was attempted so a
/// missing or misplaced fixture is easy to identify from the test output.
fn create_stream(root_path: &str, file_name: &str) -> DataStream<Result<Message, SseCodecError>> {
    // Build the path with `Path::join` rather than string formatting so separators are handled
    // by the standard library.
    let path = std::path::Path::new(root_path).join(file_name);
    let data = std::fs::read_to_string(&path)
        .unwrap_or_else(|err| panic!("failed to read replay file {}: {}", path.display(), err));
    create_message_stream(&data)
}

#[tokio::test]
async fn test_openai_chat_stream() {
    let data = std::fs::read_to_string("tests/data/replays/meta/llama-3.1-8b-instruct/chat_completions/chat-completion.streaming.1").unwrap();

    // note: we are only taking the first 16 messages to keep the size of the response small
    let stream = create_message_stream(&data).take(16);
43
44
45
46
47
48
    let result = NvCreateChatCompletionResponse::from_sse_stream(
        Box::pin(stream),
        ParsingOptions::default(),
    )
    .await
    .unwrap();
49
50
51

    // todo: provide a cleaner way to extract the content from choices
    assert_eq!(
52
53
        get_text(
            result
54
                .inner
55
56
57
58
59
60
61
62
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .expect("there to be content")
        ),
63
64
65
66
67
68
69
        "Deep learning is a subfield of machine learning that involves the use of artificial"
    );
}

#[tokio::test]
async fn test_openai_chat_edge_case_multi_line_data() {
    let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data");
70
71
72
73
74
75
    let result = NvCreateChatCompletionResponse::from_sse_stream(
        Box::pin(stream),
        ParsingOptions::default(),
    )
    .await
    .unwrap();
76

Paul Hendricks's avatar
Paul Hendricks committed
77
    assert_eq!(
78
79
        get_text(
            result
80
                .inner
81
82
83
84
85
86
87
88
89
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .expect("there to be content")
        ),
        "Deep learning"
Paul Hendricks's avatar
Paul Hendricks committed
90
    );
91
92
93
94
95
}

#[tokio::test]
async fn test_openai_chat_edge_case_comments_per_response() {
    let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response");
96
97
98
99
100
101
    let result = NvCreateChatCompletionResponse::from_sse_stream(
        Box::pin(stream),
        ParsingOptions::default(),
    )
    .await
    .unwrap();
102

Paul Hendricks's avatar
Paul Hendricks committed
103
    assert_eq!(
104
105
        get_text(
            result
106
                .inner
107
108
109
110
111
112
113
114
115
                .choices
                .first()
                .unwrap()
                .message
                .content
                .as_ref()
                .expect("there to be content")
        ),
        "Deep learning"
Paul Hendricks's avatar
Paul Hendricks committed
116
    );
117
118
119
120
121
}

/// Feeds the `edge_cases/invalid-deserialize_error` chat replay to the aggregator and checks
/// that aggregation reports an error instead of producing a response.
#[tokio::test]
async fn test_openai_chat_edge_case_invalid_deserialize_error() {
    let replay = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error");
    let options = ParsingOptions::default();

    // The outcome is deliberately left as a `Result`; only its variant is inspected.
    let result =
        NvCreateChatCompletionResponse::from_sse_stream(Box::pin(replay), options).await;

    assert!(result.is_err());
    // insta::assert_debug_snapshot!(result);
}

// =============================
// Completions (/v1/completions)
// =============================

/// Aggregates the first 16 messages of the `completion.streaming.1` replay and checks the
/// content of the first choice.
#[tokio::test]
async fn test_openai_cmpl_stream() {
    // Only the first 16 messages of the replay are consumed.
    let replay = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16);
    let options = ParsingOptions::default();

    let response = NvCreateCompletionResponse::from_sse_stream(Box::pin(replay), options)
        .await
        .unwrap();

    // todo: provide a cleaner way to extract the content from choices
    let first_choice = response.inner.choices.first().unwrap();
    assert_eq!(
        first_choice.content(),
        " This is a question that is often asked by those outside of AI research and development"
    );
}
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259

// ===================================
// nvext aggregation regression tests
// ===================================

#[allow(deprecated)]
fn make_stream_delta(
    content: Option<&str>,
    nvext: Option<serde_json::Value>,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
    Annotated::from_data(NvCreateChatCompletionStreamResponse {
        inner: CreateChatCompletionStreamResponse {
            id: "test-id".to_string(),
            choices: if let Some(text) = content {
                vec![ChatChoiceStream {
                    index: 0,
                    delta: ChatCompletionStreamResponseDelta {
                        content: Some(ChatCompletionMessageContent::Text(text.to_string())),
                        function_call: None,
                        tool_calls: None,
                        role: Some(Role::Assistant),
                        refusal: None,
                        reasoning_content: None,
                    },
                    finish_reason: None,
                    stop_reason: None,
                    logprobs: None,
                }]
            } else {
                vec![]
            },
            created: 1234567890,
            model: "test-model".to_string(),
            service_tier: None,
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
        },
        nvext,
    })
}

/// Checks that an nvext value attached to one stream delta is present on the aggregated
/// response, and that the text of all deltas is concatenated in order.
#[tokio::test]
async fn test_nvext_passthrough_aggregation() {
    let nvext_value = serde_json::json!({"custom_field": "test_value"});

    // Only the middle delta carries nvext.
    let stream = futures::stream::iter(vec![
        make_stream_delta(Some("Hello"), None),
        make_stream_delta(Some(" world"), Some(nvext_value.clone())),
        make_stream_delta(Some("!"), None),
    ]);

    let options = ParsingOptions::default();
    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream, options)
        .await
        .unwrap();

    assert_eq!(response.nvext, Some(nvext_value));

    let first_choice = response.inner.choices.first().unwrap();
    let content = first_choice.message.content.as_ref().unwrap();
    assert_eq!(get_text(content), "Hello world!");
}

/// Checks that when several deltas carry nvext, the aggregated response holds the value from
/// the last delta that had one.
#[tokio::test]
async fn test_nvext_last_value_wins() {
    let first_nvext = serde_json::json!({"version": 1});
    let last_nvext = serde_json::json!({"version": 2});

    // First and third deltas carry nvext; the one in between does not.
    let stream = futures::stream::iter(vec![
        make_stream_delta(Some("a"), Some(first_nvext)),
        make_stream_delta(Some("b"), None),
        make_stream_delta(Some("c"), Some(last_nvext.clone())),
    ]);

    let options = ParsingOptions::default();
    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream, options)
        .await
        .unwrap();

    assert_eq!(response.nvext, Some(last_nvext));
}

/// Checks that the aggregated response has no nvext when none of the deltas supplied one.
#[tokio::test]
async fn test_nvext_none_when_absent() {
    // A single delta with text and no nvext.
    let stream = futures::stream::iter(vec![make_stream_delta(Some("hello"), None)]);

    let options = ParsingOptions::default();
    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream, options)
        .await
        .unwrap();

    assert_eq!(response.nvext, None);
}