delta.rs 5.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
17
18
use crate::protocols::common;

19
impl NvCreateChatCompletionRequest {
20
21
22
23
24
    // put this method on the request
    // inspect the request to extract options
    pub fn response_generator(&self) -> DeltaGenerator {
        let options = DeltaGeneratorOptions {
            enable_usage: true,
Paul Hendricks's avatar
Paul Hendricks committed
25
            enable_logprobs: self.inner.logprobs.unwrap_or(false),
26
27
        };

Paul Hendricks's avatar
Paul Hendricks committed
28
        DeltaGenerator::new(self.inner.model.clone(), options)
29
30
31
32
33
34
35
36
37
38
39
40
41
    }
}

/// Switches that control which optional data a `DeltaGenerator` tracks and emits.
///
/// `Default` turns every switch off.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
    /// When set, completion tokens are counted and the running usage is
    /// attached to each emitted chunk.
    pub enable_usage: bool,
    /// Mirrors the request's `logprobs` flag (set in `response_generator`).
    pub enable_logprobs: bool,
}

#[derive(Debug, Clone)]
pub struct DeltaGenerator {
    id: String,
    object: String,
Paul Hendricks's avatar
Paul Hendricks committed
42
    created: u32,
43
44
    model: String,
    system_fingerprint: Option<String>,
Paul Hendricks's avatar
Paul Hendricks committed
45
46
    service_tier: Option<async_openai::types::ServiceTierResponse>,
    usage: async_openai::types::CompletionUsage,
47
48
49
50
51
52
53
54
55

    // counter on how many messages we have issued
    msg_counter: u64,

    options: DeltaGeneratorOptions,
}

impl DeltaGenerator {
    pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
Paul Hendricks's avatar
Paul Hendricks committed
56
57
58
59
        // SAFETY: This is a fun one to write. We are casting from u64 to u32
        // which typically is unsafe due to loss of precision after it
        // exceeds u32::MAX. Fortunately, this won't be an issue until
        // 2106. So whoever is still maintaining this then, enjoy!
60
61
62
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
Paul Hendricks's avatar
Paul Hendricks committed
63
64
65
66
67
68
69
70
71
            .as_secs() as u32;

        let usage = async_openai::types::CompletionUsage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
72
73
74
75
76
77
78
79

        Self {
            id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
            object: "chat.completion.chunk".to_string(),
            created: now,
            model,
            system_fingerprint: None,
            service_tier: None,
Paul Hendricks's avatar
Paul Hendricks committed
80
            usage,
81
82
83
84
85
            msg_counter: 0,
            options,
        }
    }

Paul Hendricks's avatar
Paul Hendricks committed
86
    pub fn update_isl(&mut self, isl: u32) {
87
88
89
        self.usage.prompt_tokens = isl;
    }

Paul Hendricks's avatar
Paul Hendricks committed
90
    #[allow(deprecated)]
91
92
    pub fn create_choice(
        &self,
Paul Hendricks's avatar
Paul Hendricks committed
93
        index: u32,
94
        text: Option<String>,
Paul Hendricks's avatar
Paul Hendricks committed
95
96
97
98
99
100
        finish_reason: Option<async_openai::types::FinishReason>,
        logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
    ) -> async_openai::types::CreateChatCompletionStreamResponse {
        // TODO: Update for tool calling
        // ALLOW: function_call is deprecated
        let delta = async_openai::types::ChatCompletionStreamResponseDelta {
101
            role: if self.msg_counter == 0 {
Paul Hendricks's avatar
Paul Hendricks committed
102
                Some(async_openai::types::Role::Assistant)
103
104
105
            } else {
                None
            },
Paul Hendricks's avatar
Paul Hendricks committed
106
            content: text,
107
            tool_calls: None,
Paul Hendricks's avatar
Paul Hendricks committed
108
109
            function_call: None,
            refusal: None,
110
111
        };

Paul Hendricks's avatar
Paul Hendricks committed
112
113
114
115
116
117
118
119
120
121
        let choice = async_openai::types::ChatChoiceStream {
            index,
            delta,
            finish_reason,
            logprobs,
        };

        let choices = vec![choice];

        async_openai::types::CreateChatCompletionStreamResponse {
122
123
124
125
126
            id: self.id.clone(),
            object: self.object.clone(),
            created: self.created,
            model: self.model.clone(),
            system_fingerprint: self.system_fingerprint.clone(),
Paul Hendricks's avatar
Paul Hendricks committed
127
            choices,
128
129
130
131
132
133
134
135
136
137
            usage: if self.options.enable_usage {
                Some(self.usage.clone())
            } else {
                None
            },
            service_tier: self.service_tier.clone(),
        }
    }
}

138
139
140
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
    for DeltaGenerator
{
141
142
143
    fn choice_from_postprocessor(
        &mut self,
        delta: crate::protocols::common::llm_backend::BackendOutput,
144
    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
145
146
        // aggregate usage
        if self.options.enable_usage {
Paul Hendricks's avatar
Paul Hendricks committed
147
            self.usage.completion_tokens += delta.token_ids.len() as u32;
148
149
150
151
152
153
        }

        // todo logprobs
        let logprobs = None;

        let finish_reason = match delta.finish_reason {
Paul Hendricks's avatar
Paul Hendricks committed
154
155
156
157
            Some(common::FinishReason::EoS) => Some(async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop),
            Some(common::FinishReason::Length) => Some(async_openai::types::FinishReason::Length),
            Some(common::FinishReason::Cancelled) => Some(async_openai::types::FinishReason::Stop),
158
159
160
161
162
163
164
165
            Some(common::FinishReason::Error(err_msg)) => {
                return Err(anyhow::anyhow!(err_msg));
            }
            None => None,
        };

        // create choice
        let index = 0;
Paul Hendricks's avatar
Paul Hendricks committed
166
167
        let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);

168
        Ok(NvCreateChatCompletionStreamResponse {
Paul Hendricks's avatar
Paul Hendricks committed
169
170
            inner: stream_response,
        })
171
172
    }
}