// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::{num::NonZero, path::Path, sync::Arc};

use async_openai::types::FinishReason;
use async_stream::stream;
use async_trait::async_trait;
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
    AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
    GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
    Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens, TokenSource,
};
use tokio::sync::mpsc::channel;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;

use crate::protocols::openai::chat_completions::{
    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;

/// How many requests mistralrs will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this.
/// I would call this the batch size but apparently that's something else.
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;

/// Experimental: Switch this to true to enable paged attention on CUDA devices.
/// Under load (dynamo-run batch mode) paged attention sometimes returns an immediate
/// finish_reason=stop and no tokens for one of the requests.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;

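/// Build a streaming OpenAI chat-completions engine backed by mistralrs.
///
/// A minimal usage sketch (hypothetical caller; assumes a tokio runtime and a
/// local GGUF file):
///
/// ```ignore
/// let engine = make_engine(std::path::Path::new("model.gguf")).await?;
/// // `engine` is an Arc'd AsyncEngine ready to serve chat completion requests.
/// ```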
pub async fn make_engine(
    gguf_path: &Path,
) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> {
    let engine = MistralRsEngine::new(gguf_path).await?;
    let engine: OpenAIChatCompletionsStreamingEngine = Arc::new(engine);
    Ok(engine)
}

/// Gets the best available device: Metal when built with the `metal` feature,
/// otherwise CUDA if available, falling back to CPU.
fn best_device() -> pipeline_error::Result<Device> {
    #[cfg(not(feature = "metal"))]
    {
        Ok(Device::cuda_if_available(0)?)
    }
    #[cfg(feature = "metal")]
    {
        Ok(Device::new_metal(0)?)
    }
}

struct MistralRsEngine {
    mistralrs: Arc<MistralRs>,
}

impl MistralRsEngine {
    async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
        let loader = if model_path.is_file() {
            // Load from a GGUF
            let Some(model_filename) = model_path.file_name() else {
                pipeline_error::bail!("Missing filename in model path");
            };
            let Some(model_dir) = model_path.parent() else {
                pipeline_error::bail!("Invalid model path");
            };

            GGUFLoaderBuilder::new(
                None,
                None,
                model_dir.display().to_string(),
                vec![model_filename.to_string_lossy().into_owned()],
                GGUFSpecificConfig {
                    prompt_chunksize: None,
                    topology: None,
                },
            )
            .build()
        } else {
            // Load from a HF repo dir
            NormalLoaderBuilder::new(
                NormalSpecificConfig {
                    use_flash_attn: false,
                    prompt_chunksize: None,
                    topology: None,
                    organization: Default::default(),
                    write_uqff: None,
                    from_uqff: None,
                    imatrix: None,
                    calibration_file: None,
                },
                None,
                None,
                Some(model_path.display().to_string()),
            )
            .build(None)?
        };

        let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;

        // Paged attention requires cuda
        let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
            Some(PagedAttentionConfig::new(
                None, // Block size, default 32
                4096, // CPU memory in MiB
                MemoryGpuConfig::ContextSize(max_seq_len),
            )?)
        } else {
            None
        };
        // Load the model into a Pipeline
        let pipeline = loader.load_model_from_hf(
            None,
            TokenSource::None, // The model was already downloaded
            &ModelDType::Auto,
            &best_device()?,
            false,
            DeviceMapSetting::Auto(AutoDeviceMapParams::Text {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
            }),
            None,
            paged_attention_config,
        )?;
        let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
            tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
            let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
                Some(conf) => conf.clone(),
                None => {
                    anyhow::bail!("Failed loading model config");
                }
            };
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
                config,
            }
        } else {
            tracing::debug!("Using mistralrs DefaultScheduler");
            SchedulerConfig::DefaultScheduler {
                // Safety: unwrap trivially safe here
                method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
            }
        };
        // Create the MistralRs, which is a runner
        let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
        let engine = MistralRsEngine {
            mistralrs: builder.build(),
        };
        // skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
        let _ = engine.mistralrs.next_request_id();
        Ok(engine)
    }
}

#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        let (request, context) = request.transfer(());
        let ctx = context.context();
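        // Buffered channel (capacity 10_000) carrying streamed responses back from mistralrs.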
        let (tx, mut rx) = channel(10_000);

        let mut messages = vec![];
        for m in request.inner.messages {
            let async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
                continue;
            };
            let async_openai::types::ChatCompletionRequestUserMessageContent::Text(content) =
                inner_m.content
            else {
                anyhow::bail!("Only Text type chat completion supported");
            };
            let r = IndexMap::from([
                ("role".to_string(), Either::Left("user".to_string())),
                ("content".to_string(), Either::Left(content)),
            ]);
            messages.push(r);
        }
        if messages.is_empty() {
            anyhow::bail!("Empty request");
        }

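        // Start from mistralrs' deterministic sampling defaults and override each
        // field with whatever the OpenAI request provides.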
        let det = SamplingParams::deterministic();
        // `max_tokens` is deprecated in async-openai in favor of `max_completion_tokens`,
        // but we still read it as a fallback below.
        #[allow(deprecated)]
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .top_logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
            max_len: request
                .inner
                .max_completion_tokens
                .or(request.inner.max_tokens)
                .map(|m| m as usize)
                .or(det.max_len),
            logits_bias: request
                .inner
                .logit_bias
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };
        let request_id = self.mistralrs.next_request_id();
        let mistralrs_request = Request::Normal(NormalRequest {
            id: request_id,
            messages: RequestMessage::Chat(messages),
            sampling_params,
            response: tx,
            return_logprobs: request.inner.logprobs.unwrap_or_default(),
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            adapters: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
        });

        self.mistralrs.get_sender()?.send(mistralrs_request).await?;

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
                        break;
                    }
                };
                match response {
                    ResponseOk::Chunk(c) => {
                        let Some(from_assistant) = c.choices[0].delta.content.clone() else {
                            tracing::warn!(request_id, "No content from mistralrs. Abandoning request.");
                            break;
                        };
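                        // Map mistralrs' string finish reasons onto async_openai's FinishReason enum.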
                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
                                Some(FinishReason::Length)
                            }
                            Some(s) => {
                                tracing::warn!(request_id, stop_reason = s, "Unknown stop reason");
                                Some(FinishReason::Stop)
                            }
                            None => None,
                        };
                        //tracing::trace!("from_assistant: {from_assistant}");

                        #[allow(deprecated)]
                        let inner = async_openai::types::CreateChatCompletionStreamResponse{
                            id: c.id,
                            choices: vec![async_openai::types::ChatChoiceStream{
                                index: 0,
                                delta: async_openai::types::ChatCompletionStreamResponseDelta{
                                    //role: c.choices[0].delta.role,
                                    role: Some(async_openai::types::Role::Assistant),
                                    content: Some(from_assistant),
                                    tool_calls: None,
                                    refusal: None,
                                    function_call: None,
                                },
                                logprobs: None,
                                finish_reason,
                            }],
                            model: c.model,
                            created: c.created as u32,
                            object: c.object.clone(),
                            usage: None,
                            system_fingerprint: Some(c.system_fingerprint),
                            service_tier: None,
                        };
                        let delta = NvCreateChatCompletionStreamResponse{inner};
                        let ann = Annotated{
                            id: None,
                            data: Some(delta),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            //tracing::trace!(request_id, "Finish reason: {finish_reason:?}");
                            break;
                        }
                    },
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}

/// Convert OpenAI stop tokens to mistralrs stop tokens.
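///
/// A hypothetical sketch of the mapping (ignored doc test; the helper is private):
///
/// ```ignore
/// let stop = to_stop_tokens(async_openai::types::Stop::String("</s>".to_string()));
/// assert!(matches!(stop, StopTokens::Seqs(seqs) if seqs == ["</s>"]));
/// ```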
fn to_stop_tokens(t: async_openai::types::Stop) -> StopTokens {
    match t {
        async_openai::types::Stop::String(s) => StopTokens::Seqs(vec![s]),
        async_openai::types::Stop::StringArray(v) => StopTokens::Seqs(v),
    }
}

/// Convert an OpenAI logit bias map (string token IDs / JSON numbers) to the
/// mistralrs form (u32 token IDs / f32 biases).
/// The input is expected to look like this: {"3721": -100, "17765": 100}
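///
/// A hypothetical sketch of the conversion (ignored doc test; token IDs are
/// arbitrary):
///
/// ```ignore
/// let lb = HashMap::from([
///     ("3721".to_string(), serde_json::json!(-100)),
///     ("17765".to_string(), serde_json::json!(100)),
/// ]);
/// assert_eq!(to_logit_bias(lb).get(&3721), Some(&-100.0));
/// ```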
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let mut out = HashMap::new();
    for (key, value) in &lb {
        let token_id: u32 = match key.parse() {
            Ok(t) => t,
            Err(err) => {
                tracing::warn!(
                    "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                );
                return HashMap::new();
            }
        };
        let Some(bias) = value.as_f64() else {
            tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
            return HashMap::new();
        };
        out.insert(token_id, bias as f32);
    }
    out
}