openai_completions.rs 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use async_openai::types::CreateCompletionRequestArgs;
17
use serde::{Deserialize, Serialize};
18
use triton_distributed_llm::protocols::openai::{self, completions::CompletionRequest};
19
20
21
22
23
24
25
26
27
28

#[derive(Serialize, Deserialize, Debug, Clone)]
struct CompletionSample {
    request: CompletionRequest,
    description: String,
}

impl CompletionSample {
    fn new<F>(description: impl Into<String>, configure: F) -> Result<Self, String>
    where
29
        F: FnOnce(&mut CreateCompletionRequestArgs) -> &mut CreateCompletionRequestArgs,
30
    {
31
        let mut builder = CreateCompletionRequestArgs::default();
32
33
34
35
        builder
            .model("gpt-3.5-turbo")
            .prompt("What is the meaning of life?");
        configure(&mut builder);
36
37
38
39
40

        let inner = builder.build().unwrap();

        let request = CompletionRequest { inner, nvext: None };

41
        Ok(Self {
42
            request,
43
44
45
46
47
48
49
            description: description.into(),
        })
    }
}

#[test]
fn minimum_viable_request() {
50
    let request = CreateCompletionRequestArgs::default()
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
        .prompt("What is the meaning of life?")
        .model("gpt-3.5-turbo")
        .build()
        .expect("error building request");

    insta::assert_json_snapshot!(request);
}

#[test]
fn valid_samples() {
    let mut settings = insta::Settings::clone_current();
    settings.set_sort_maps(true);
    let _guard = settings.bind_to_scope();

    let samples = build_samples().expect("error building samples");

    // iteration on all sample and call validate and expect it to be ok
    for sample in &samples {
        insta::with_settings!({
            description => &sample.description,
        }, {
        insta::assert_json_snapshot!(sample.request);
        });
    }
}
#[allow(clippy::vec_init_then_push)]
fn build_samples() -> Result<Vec<CompletionSample>, String> {
    let mut samples = Vec::new();

    samples.push(CompletionSample::new(
        "should have only prompt and model fields",
        |builder| builder,
    )?);

    samples.push(CompletionSample::new(
        "should have prompt, model, and max_tokens fields",
Paul Hendricks's avatar
Paul Hendricks committed
87
        |builder| builder.max_tokens(10_u32),
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
    )?);

    samples.push(CompletionSample::new(
        "should have prompt, model, and temperature fields",
        |builder| builder.temperature(openai::MIN_TEMPERATURE),
    )?);

    samples.push(CompletionSample::new(
        "should have prompt, model, and top_p fields",
        |builder| builder.top_p(openai::MIN_TOP_P),
    )?);

    samples.push(CompletionSample::new(
        "should have prompt, model, and frequency_penalty fields",
        |builder| builder.frequency_penalty(openai::MIN_FREQUENCY_PENALTY),
    )?);

    samples.push(CompletionSample::new(
        "should have prompt, model, and presence_penalty fields",
        |builder| builder.presence_penalty(openai::MIN_PRESENCE_PENALTY),
    )?);

    samples.push(CompletionSample::new(
        "should have prompt, model, and stop fields",
        |builder| builder.stop(vec!["\n".to_string()]),
    )?);

    samples.push(CompletionSample::new(
        "should have prompt, model, and echo fields",
        |builder| builder.echo(true),
    )?);

    samples.push(CompletionSample::new(
        "should have prompt, model, and stream fields",
        |builder| builder.stream(true),
    )?);

    Ok(samples)
}