// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Tests for tool_choice finish_reason handling.

use dynamo_async_openai::types::{
    ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
    ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption,
    ChatCompletionToolType, CreateChatCompletionRequest, FunctionName,
};
use dynamo_llm::protocols::common;
use dynamo_llm::protocols::common::llm_backend::BackendOutput;
use dynamo_llm::protocols::openai::DeltaGeneratorExt;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;

/// Builds a minimal non-streaming chat completion request for model `test-model`
/// containing a single user message with the text `"test"`.
///
/// All remaining request fields are left at their defaults (or `None`).
fn create_test_request() -> NvCreateChatCompletionRequest {
    let messages = vec![ChatCompletionRequestMessage::User(
        ChatCompletionRequestUserMessage {
            content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
            name: None,
        },
    )];

    NvCreateChatCompletionRequest {
        inner: CreateChatCompletionRequest {
            model: "test-model".to_string(),
            messages,
            stream: Some(false),
            stream_options: None,
            ..Default::default()
        },
        common: Default::default(),
        nvext: None,
        chat_template_args: None,
        media_io_kwargs: None,
        unsupported_fields: Default::default(),
    }
}

/// Builds a single backend output chunk for choice index 0 that carries `text`
/// and terminates with the given `finish` reason.
///
/// Token ids, tokens, log-probs, usage and disaggregation params are all empty/`None`,
/// so only the text and finish reason influence the generated response.
fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) -> BackendOutput {
    BackendOutput {
        token_ids: vec![],
        tokens: vec![],
        text: Some(text.to_string()),
        cum_log_probs: None,
        log_probs: None,
        top_logprobs: None,
        finish_reason: Some(finish),
        stop_reason: None,
        index: Some(0),
        completion_usage: None,
        disaggregated_params: None,
    }
}

/// Runs a single raw stream response through a `JailedStream` configured from
/// `tool_choice`, and returns the first response the jail emits.
///
/// * `Named` tool choice configures the jail with `tool_choice_named(<function name>)`.
/// * `Required` configures it with `tool_choice_required()`.
/// * Any other value (including `None`) uses the builder's default configuration.
///
/// # Panics
///
/// Panics if the jailed stream yields no items, or if its first item has no `data`.
async fn apply_jail_transformation(
    raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
    tool_choice: Option<ChatCompletionToolChoiceOption>,
) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
    use dynamo_runtime::protocols::annotated::Annotated;
    use futures::StreamExt;
    use futures::stream;

    // One-element input stream: the whole raw response arrives as a single chunk.
    let input_stream = stream::iter(vec![Annotated {
        data: Some(raw_response),
        id: None,
        event: None,
        comment: None,
        error: None,
    }]);

    let mut builder = JailedStream::builder();

    // `tool_choice` is owned here, so the function name can be moved out directly.
    match tool_choice {
        Some(ChatCompletionToolChoiceOption::Named(named)) => {
            builder = builder.tool_choice_named(named.function.name);
        }
        Some(ChatCompletionToolChoiceOption::Required) => {
            builder = builder.tool_choice_required();
        }
        _ => {}
    }

    let jail = builder.build();
    let output_stream = jail.apply_with_finish_reason(input_stream);

    tokio::pin!(output_stream);
    output_stream
        .next()
        .await
        .expect("jailed stream produced no output")
        .data
        .expect("first jailed output carried no data")
}

#[tokio::test]
async fn test_named_tool_choice_preserves_length_finish_reason() {
    let mut request = create_test_request();
    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "get_weather".to_string(),
            },
        },
    ));
    request.inner.tool_choice = tool_choice.clone();

    let mut generator = request.response_generator("req-length-1".to_string());
    let backend_output = build_backend_output_with_finish(
        r#"{"location":"Par"#, // Incomplete due to length limit
        common::FinishReason::Length,
    );

    let raw_response = generator
        .choice_from_postprocessor(backend_output)
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    // Critical: Length finish reason should be preserved, NOT replaced with Stop
    assert_eq!(
119
        response.inner.choices[0].finish_reason,
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
        Some(dynamo_async_openai::types::FinishReason::Length),
        "Length finish reason must be preserved for tool_choice=named"
    );
}

/// With `tool_choice=required`, a response truncated by the length limit must keep
/// `finish_reason=Length` in the generated choice.
#[test]
fn test_required_tool_choice_preserves_length_finish_reason() {
    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required);

    let mut generator = request.response_generator("req-length-2".to_string());
    let backend_output = build_backend_output_with_finish(
        r#"[{"name":"search","parameters":{"query":"incomplete"#, // Incomplete due to length
        common::FinishReason::Length,
    );

    // NOTE(review): this inspects the delta generator output only. The normal-stop
    // test for tool_choice=required routes through `apply_jail_transformation` to
    // observe the ToolCalls rewrite, so consider also checking the jailed output here.
    let response = generator
        .choice_from_postprocessor(backend_output)
        .expect("choice generation");

    // Critical: Length finish reason should be preserved, NOT replaced with ToolCalls
    assert_eq!(
        response.inner.choices[0].finish_reason,
        Some(dynamo_async_openai::types::FinishReason::Length),
        "Length finish reason must be preserved for tool_choice=required"
    );
}

/// With `tool_choice=named`, a response stopped by the content filter must keep
/// `finish_reason=ContentFilter` in the generated choice.
#[test]
fn test_named_tool_choice_preserves_content_filter() {
    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "search".to_string(),
            },
        },
    ));

    let mut generator = request.response_generator("req-filter-1".to_string());
    let backend_output = build_backend_output_with_finish(
        r#"{"query":"filtered content"#,
        common::FinishReason::ContentFilter,
    );

    let response = generator
        .choice_from_postprocessor(backend_output)
        .expect("choice generation");

    // Critical: ContentFilter finish reason should be preserved
    assert_eq!(
        response.inner.choices[0].finish_reason,
        Some(dynamo_async_openai::types::FinishReason::ContentFilter),
        "ContentFilter finish reason must be preserved for tool_choice=named"
    );
}

/// With `tool_choice=required`, a response stopped by the content filter must keep
/// `finish_reason=ContentFilter` in the generated choice.
#[test]
fn test_required_tool_choice_preserves_content_filter() {
    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required);

    let mut generator = request.response_generator("req-filter-2".to_string());
    let backend_output = build_backend_output_with_finish(
        r#"[{"name":"harmful_action"#,
        common::FinishReason::ContentFilter,
    );

    // NOTE(review): like the Length variant, this checks only the delta generator
    // output, not the jailed output where tool_choice=required rewrites happen in
    // the normal-stop test.
    let response = generator
        .choice_from_postprocessor(backend_output)
        .expect("choice generation");

    // Critical: ContentFilter finish reason should be preserved
    assert_eq!(
        response.inner.choices[0].finish_reason,
        Some(dynamo_async_openai::types::FinishReason::ContentFilter),
        "ContentFilter finish reason must be preserved for tool_choice=required"
    );
}

/// With `tool_choice=named`, a normally completed response keeps `finish_reason=Stop`.
#[test]
fn test_named_tool_choice_normal_stop_becomes_stop() {
    let mut request = create_test_request();
    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named(
        ChatCompletionNamedToolChoice {
            r#type: ChatCompletionToolType::Function,
            function: FunctionName {
                name: "get_weather".to_string(),
            },
        },
    ));

    let mut generator = request.response_generator("req-stop-1".to_string());
    let backend_output = build_backend_output_with_finish(
        r#"{"location":"Paris","unit":"celsius"}"#,
        common::FinishReason::Stop,
    );

    let response = generator
        .choice_from_postprocessor(backend_output)
        .expect("choice generation");

    // Normal completion: Stop should remain Stop for named tool choice
    assert_eq!(
        response.inner.choices[0].finish_reason,
        Some(dynamo_async_openai::types::FinishReason::Stop),
        "Stop finish reason must remain Stop for tool_choice=named"
    );
}

#[tokio::test]
async fn test_required_tool_choice_normal_stop_becomes_tool_calls() {
    let mut request = create_test_request();
    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);
    request.inner.tool_choice = tool_choice.clone();

    let mut generator = request.response_generator("req-stop-2".to_string());
    let backend_output = build_backend_output_with_finish(
        r#"[{"name":"search","parameters":{"query":"rust"}}]"#,
        common::FinishReason::Stop,
    );

    let raw_response = generator
        .choice_from_postprocessor(backend_output)
        .expect("choice generation");

    let response = apply_jail_transformation(raw_response, tool_choice).await;

    // Normal completion: Stop should become ToolCalls for required tool choice
    assert_eq!(
250
        response.inner.choices[0].finish_reason,
251
252
253
        Some(dynamo_async_openai::types::FinishReason::ToolCalls),
    );
}