// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{cmp::min, num::NonZero, path::Path, sync::Arc};

use async_openai::types::FinishReason;
use async_stream::stream;
use async_trait::async_trait;
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
    AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
    GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
    Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, TokenSource,
};
use tokio::sync::mpsc::channel;

use dynemo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynemo_runtime::pipeline::error as pipeline_error;
use dynemo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynemo_runtime::protocols::annotated::Annotated;

use crate::protocols::openai::chat_completions::{
    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;

/// If user does not provide a max_tokens limit prompt+output to this many
const DEFAULT_MAX_TOKENS: i32 = 8192;

/// Maximum sequence length fed to the automatic device mapper.
/// TODO: tune. Presumably we read it from model's config.json?
const MAX_SEQ_LEN: usize = 4096;

// Maximum batch size assumed by the automatic device mapper.
// TODO: tune, maybe implement batching.
const MAX_BATCH_SIZE: usize = 2;

/// Maximum number of sequences scheduled concurrently by the
/// paged-attention scheduler (CUDA builds only).
/// TODO: tune
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 5;
/// Construct a streaming OpenAI chat-completions engine backed by mistral.rs.
///
/// `gguf_path` may name either a single GGUF file or an HF-style model
/// directory; `MistralRsEngine::new` decides which loader to use.
pub async fn make_engine(
    gguf_path: &Path,
) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> {
    // Arc::new(...) coerces to the trait-object alias via the return type.
    Ok(Arc::new(MistralRsEngine::new(gguf_path).await?))
}

/// Gets the best device: Metal when built with the `metal` feature,
/// otherwise CUDA device 0 if available (falling back to CPU).
fn best_device() -> pipeline_error::Result<Device> {
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = Device::cuda_if_available(0)?;
    Ok(device)
}

/// Chat-completions engine that wraps a mistral.rs runner.
struct MistralRsEngine {
    // The mistral.rs runner; requests are submitted through its sender.
    mistralrs: Arc<MistralRs>,
    // The loaded model pipeline, kept so `generate` can reach the tokenizer
    // for prompt/output token accounting.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync + 'static>>,
}

impl MistralRsEngine {
    /// Load a model from `model_path` and assemble a mistral.rs runner.
    ///
    /// A file path is treated as a single GGUF model; a directory path is
    /// treated as an HF-style repo. On CUDA builds a paged-attention KV cache
    /// and scheduler are configured; otherwise the default scheduler is used.
    async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
        let loader = if model_path.is_file() {
            // Load from a GGUF
            let Some(model_filename) = model_path.file_name() else {
                pipeline_error::bail!("Missing filename in model path");
            };
            let Some(model_dir) = model_path.parent() else {
                pipeline_error::bail!("Invalid model path");
            };

            // NOTE(review): the two leading `None`s are loader options we do
            // not set here (chat template / tok config, per the
            // GGUFLoaderBuilder API) — confirm against the mistralrs version
            // pinned in Cargo.toml.
            GGUFLoaderBuilder::new(
                None,
                None,
                model_dir.display().to_string(),
                vec![model_filename.to_string_lossy().into_owned()],
                GGUFSpecificConfig {
                    prompt_chunksize: None,
                    topology: None,
                },
            )
            .build()
        } else {
            // Load from a HF repo dir
            NormalLoaderBuilder::new(
                NormalSpecificConfig {
                    use_flash_attn: false,
                    prompt_chunksize: None,
                    topology: None,
                    organization: Default::default(),
                    write_uqff: None,
                    from_uqff: None,
                    imatrix: None,
                    calibration_file: None,
                },
                None,
                None,
                Some(model_path.display().to_string()),
            )
            .build(None)?
        };

        // Paged attention requires cuda
        let paged_attention_config = if cfg!(feature = "cuda") {
            // 32-token blocks, 1024-token KV allocation granularity, and 90%
            // of GPU memory utilization — values chosen here, not derived
            // from the model.
            Some(PagedAttentionConfig::new(
                Some(32),
                1024,
                MemoryGpuConfig::Utilization(0.9),
            )?)
        } else {
            None
        };
        // Load, into a Pipeline
        let pipeline = loader.load_model_from_hf(
            None,
            TokenSource::CacheToken,
            &ModelDType::Auto,
            &best_device()?,
            false,
            // Let mistral.rs map layers to devices automatically, sized by
            // the module-level MAX_SEQ_LEN / MAX_BATCH_SIZE constants.
            DeviceMapSetting::Auto(AutoDeviceMapParams::Text {
                max_seq_len: MAX_SEQ_LEN,
                max_batch_size: MAX_BATCH_SIZE,
            }),
            None,
            paged_attention_config,
        )?;
        let scheduler = if cfg!(feature = "cuda") {
            tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
            // The cache config is produced by the paged-attention setup above;
            // its absence on a CUDA build means model loading went wrong.
            let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
                Some(conf) => conf.clone(),
                None => {
                    anyhow::bail!("Failed loading model config");
                }
            };
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
                config,
            }
        } else {
            tracing::debug!("Using mistralrs DefaultScheduler");
            SchedulerConfig::DefaultScheduler {
                // Safety: unwrap trivially safe here
                method: DefaultSchedulerMethod::Fixed(NonZero::new(5).unwrap()),
            }
        };
        // Create the MistralRs, which is a runner
        let builder = MistralRsBuilder::new(pipeline.clone(), scheduler);
        Ok(MistralRsEngine {
            mistralrs: builder.build(),
            pipeline,
        })
    }
}

#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        Error,
    > for MistralRsEngine
{
    /// Run one chat-completion request and stream back OpenAI-style deltas.
    ///
    /// Only `User` messages with plain-text content are forwarded to the
    /// model; every other message kind is skipped, and non-text user content
    /// is an error. Output is capped at DEFAULT_MAX_TOKENS minus the prompt
    /// length (or the caller's `max_tokens`, whichever is smaller).
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        let (request, context) = request.transfer(());
        let ctx = context.context();
        // Channel the mistral.rs runner delivers response chunks on.
        let (tx, mut rx) = channel(10_000);
        // Tokenizer is optional; when absent, token accounting stays at 0.
        let maybe_tok = self.pipeline.lock().await.tokenizer();

        let mut prompt_tokens = 0i32;
        let mut messages = vec![];
        for m in request.inner.messages {
            // Skip everything except User messages (system/assistant/tool).
            let async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
                continue;
            };
            let content = match inner_m.content {
                async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt) => {
                    // NOTE(review): prompt_tokens is overwritten per message,
                    // so it ends up counting only the LAST user message —
                    // confirm this is the intended budget for `limit` below.
                    if let Some(tok) = maybe_tok.as_ref() {
                        prompt_tokens = tok
                            .encode(prompt.clone(), false)
                            .map(|e| e.len() as i32)
                            .unwrap_or(0);
                    }
                    prompt
                }
                _ => {
                    anyhow::bail!("Only Text type is supported");
                }
            };
            // mistral.rs chat messages are role/content maps.
            let r = IndexMap::from([
                ("role".to_string(), Either::Left("user".to_string())),
                ("content".to_string(), Either::Left(content)),
            ]);
            messages.push(r);
        }
        if messages.is_empty() {
            anyhow::bail!("Empty request");
        }
        // TODO tracing::trace print the latest prompt, which should be the last message at user
        // level.
        //tracing::info!(prompt_tokens, "Received prompt");
        // NOTE(review): limit can go negative when the prompt exceeds
        // DEFAULT_MAX_TOKENS; that makes the Length finish fire on the first
        // chunk — confirm that is the desired behavior.
        let limit = DEFAULT_MAX_TOKENS - prompt_tokens;
        #[allow(deprecated)]
        let max_output_tokens = min(
            request.inner.max_tokens.map(|x| x as i32).unwrap_or(limit),
            limit,
        );

        let mistralrs_request = Request::Normal(NormalRequest {
            messages: RequestMessage::Chat(messages),
            // Greedy decoding: no temperature/top-p sampling.
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            id: self.mistralrs.next_request_id(),
            constraint: Constraint::None,
            suffix: None,
            adapters: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
        });

        self.mistralrs.get_sender()?.send(mistralrs_request).await?;

        let mut used_output_tokens = 0;
        // Lazily re-shape mistral.rs chunks into annotated OpenAI stream
        // responses; errors and missing content end the stream early.
        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(%err, "Failed converting mistralrs channel response to result.");
                        break;
                    }
                };
                match response {
                    ResponseOk::Chunk(c) => {
                        // NOTE(review): c.choices[0] panics if mistral.rs ever
                        // sends an empty choices vec — consider c.choices.first().
                        let Some(from_assistant) = c.choices[0].delta.content.clone() else {
                            tracing::warn!("No content from mistralrs. Abandoning request.");
                            break;
                        };
                        // Re-tokenize the emitted text to track output budget.
                        if let Some(tok) = maybe_tok.as_ref() {
                            used_output_tokens += tok
                                .encode(from_assistant.clone(), false)
                                .map(|e| e.len() as i32)
                                .unwrap_or(0);
                        }
                        let finish_reason = match &c.choices[0].finish_reason {
                            // Any upstream finish reason is mapped to Stop.
                            Some(_fr) => Some(FinishReason::Stop), //Some(fr.parse::<FinishReason>().unwrap_or(FinishReason::Stop)),
                            None if used_output_tokens >= max_output_tokens => {
                                tracing::debug!(used_output_tokens, max_output_tokens, "Met or exceed max_tokens. Stopping.");
                                Some(FinishReason::Length)
                            }
                            None => None,
                        };
                        //tracing::trace!("from_assistant: {from_assistant}");

                        #[allow(deprecated)]
                        let inner = async_openai::types::CreateChatCompletionStreamResponse{
                            id: c.id,
                            choices: vec![async_openai::types::ChatChoiceStream{
                                index: 0,
                                delta: async_openai::types::ChatCompletionStreamResponseDelta{
                                    //role: c.choices[0].delta.role,
                                    role: Some(async_openai::types::Role::Assistant),
                                    content: Some(from_assistant),
                                    tool_calls: None,
                                    refusal: None,
                                    function_call: None,
                                },
                                logprobs: None,
                                finish_reason,
                            }],
                            model: c.model,
                            // NOTE(review): narrowing u64->u32 epoch cast
                            // overflows in 2106; async-openai's field type
                            // forces it — worth a try_into + log instead.
                            created: c.created as u32,
                            object: c.object.clone(),
                            usage: None,
                            system_fingerprint: Some(c.system_fingerprint),
                            service_tier: None,
                        };
                        let delta = NvCreateChatCompletionStreamResponse{inner};
                        let ann = Annotated{
                            id: None,
                            data: Some(delta),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            //tracing::trace!("Finish reason: {finish_reason:?}");
                            break;
                        }
                    },
                    // Non-chunk responses (e.g. completion-done variants) are
                    // logged and dropped.
                    x => tracing::error!("Unhandled. {x:?}"),
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}