// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::{cmp::min, env, num::NonZero, path::Path, sync::Arc};

use async_openai::types::FinishReason;
use async_stream::stream;
use async_trait::async_trait;
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
    AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
    GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
    Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens,
    TokenSource,
};
use tokio::sync::mpsc::channel;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;

use crate::protocols::openai::chat_completions::{
    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;

/// If the user does not provide max_tokens, limit prompt + output to this many tokens.
const DEFAULT_MAX_TOKENS: i32 = 8192;

/// TODO: tune
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 5;

/// Environment variables that may hold a Hugging Face token, checked in this order.
const HF_TOKEN_VARS: [&str; 3] = ["HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN"];

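/// Build a streaming OpenAI chat-completions engine backed by mistralrs.
///
/// Despite the parameter name, `gguf_path` may point at either a single GGUF file
/// or a Hugging Face model directory; `MistralRsEngine::new` handles both cases.
///
/// Minimal usage sketch (the model path is illustrative only):
///
/// ```ignore
/// let engine = make_engine(Path::new("/models/model.gguf")).await?;
/// ```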
pub async fn make_engine(
    gguf_path: &Path,
) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> {
    let engine = MistralRsEngine::new(gguf_path).await?;
    let engine: OpenAIChatCompletionsStreamingEngine = Arc::new(engine);
    Ok(engine)
}

/// Gets the best available device: Metal when built with the `metal` feature,
/// otherwise CUDA if available, falling back to CPU.
fn best_device() -> pipeline_error::Result<Device> {
    #[cfg(not(feature = "metal"))]
    {
        Ok(Device::cuda_if_available(0)?)
    }
    #[cfg(feature = "metal")]
    {
        Ok(Device::new_metal(0)?)
    }
}

struct MistralRsEngine {
    mistralrs: Arc<MistralRs>,
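    // Kept alongside the runner so `generate` can reach the tokenizer for token counting.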
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync + 'static>>,
}

impl MistralRsEngine {
    async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
        let mut hf_token_source = TokenSource::CacheToken;
        // We might be trying to download a repo from Hugging Face. See if we have a token.
        if !model_path.exists() {
            for v_name in HF_TOKEN_VARS {
                if env::var(v_name).is_ok() {
                    tracing::debug!("Using Hugging Face token from {v_name}");
                    hf_token_source = TokenSource::EnvVar(v_name.to_string());
                    break;
                }
            }
        }
        let loader = if model_path.is_file() {
            // Load from a GGUF
            let Some(model_filename) = model_path.file_name() else {
                pipeline_error::bail!("Missing filename in model path");
            };
            let Some(model_dir) = model_path.parent() else {
                pipeline_error::bail!("Invalid model path");
            };

            GGUFLoaderBuilder::new(
                None,
                None,
                model_dir.display().to_string(),
                vec![model_filename.to_string_lossy().into_owned()],
                GGUFSpecificConfig {
                    prompt_chunksize: None,
                    topology: None,
                },
            )
            .build()
        } else {
            // Load from a HF repo dir
            NormalLoaderBuilder::new(
                NormalSpecificConfig {
                    use_flash_attn: false,
                    prompt_chunksize: None,
                    topology: None,
                    organization: Default::default(),
                    write_uqff: None,
                    from_uqff: None,
                    imatrix: None,
                    calibration_file: None,
                },
                None,
                None,
                Some(model_path.display().to_string()),
            )
            .build(None)?
        };

        let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;

        // Paged attention requires cuda
        let paged_attention_config = if cfg!(feature = "cuda") {
            Some(PagedAttentionConfig::new(
                None, // Block size, default 32
                512,  // CPU memory in MiB
                MemoryGpuConfig::ContextSize(max_seq_len),
            )?)
        } else {
            None
        };
        // Load the model into a Pipeline
        let pipeline = loader.load_model_from_hf(
            None,
            hf_token_source,
            &ModelDType::Auto,
            &best_device()?,
            false,
            DeviceMapSetting::Auto(AutoDeviceMapParams::Text {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
            }),
            None,
            paged_attention_config,
        )?;
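        // Choose a scheduler: paged-attention scheduling on CUDA builds, the fixed-size
        // default scheduler otherwise.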
        let scheduler = if cfg!(feature = "cuda") {
            tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
            let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
                Some(conf) => conf.clone(),
                None => {
                    anyhow::bail!("Failed loading model config");
                }
            };
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
                config,
            }
        } else {
            tracing::debug!("Using mistralrs DefaultScheduler");
            SchedulerConfig::DefaultScheduler {
                // Safety: unwrap trivially safe here
                method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
            }
        };
        // Create the MistralRs, which is a runner
        let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
        Ok(MistralRsEngine {
            mistralrs: builder.build(),
            pipeline,
        })
    }
}

#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);
        let maybe_tok = self.pipeline.lock().await.tokenizer();

        let mut prompt_tokens = 0i32;
        let mut messages = vec![];
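        // Convert the OpenAI messages into mistralrs chat format. Non-user messages are
        // skipped and only plain-text user content is accepted; `prompt_tokens` ends up
        // holding the token count of the last user message, used to budget output below.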
        for m in request.inner.messages {
            let async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
                continue;
            };
            let content = match inner_m.content {
                async_openai::types::ChatCompletionRequestUserMessageContent::Text(prompt) => {
                    if let Some(tok) = maybe_tok.as_ref() {
                        prompt_tokens = tok
                            .encode(prompt.clone(), false)
                            .map(|e| e.len() as i32)
                            .unwrap_or(0);
                    }
                    prompt
                }
                _ => {
                    anyhow::bail!("Only Text type is supported");
                }
            };
            let r = IndexMap::from([
                ("role".to_string(), Either::Left("user".to_string())),
                ("content".to_string(), Either::Left(content)),
            ]);
            messages.push(r);
        }
        if messages.is_empty() {
            anyhow::bail!("Empty request");
        }
        // TODO: tracing::trace the latest prompt, which should be the last user-level message.
        //tracing::info!(prompt_tokens, "Received prompt");
        let limit = DEFAULT_MAX_TOKENS - prompt_tokens;
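        // Cap the completion so prompt + output stays within DEFAULT_MAX_TOKENS, honoring
        // a smaller client-supplied max_tokens if present.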
        #[allow(deprecated)]
        let max_output_tokens = min(
            request.inner.max_tokens.map(|x| x as i32).unwrap_or(limit),
            limit,
        );

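        // Map the OpenAI sampling fields onto mistralrs SamplingParams, falling back to
        // mistralrs' deterministic defaults for anything the request leaves unset.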
        let det = SamplingParams::deterministic();
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .top_logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
            max_len: request
                .inner
                .max_completion_tokens
                .map(|m| m as usize)
                .or(det.max_len),
            logits_bias: request
                .inner
                .logit_bias
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };
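        // Build a streaming mistralrs request; generated chunks arrive on `rx`.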
        let mistralrs_request = Request::Normal(NormalRequest {
            messages: RequestMessage::Chat(messages),
            sampling_params,
            response: tx,
            return_logprobs: request.inner.logprobs.unwrap_or_default(),
            is_streaming: true,
            id: self.mistralrs.next_request_id(),
            constraint: Constraint::None,
            suffix: None,
            adapters: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
        });

        self.mistralrs.get_sender()?.send(mistralrs_request).await?;

        let mut used_output_tokens = 0;
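        // Re-emit each mistralrs chunk as an OpenAI stream response, counting output
        // tokens so generation stops once max_output_tokens is reached.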
        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(%err, "Failed converting mistralrs channel response to result.");
                        break;
                    }
                };
                match response {
                    ResponseOk::Chunk(c) => {
                        let Some(from_assistant) = c.choices[0].delta.content.clone() else {
                            tracing::warn!("No content from mistralrs. Abandoning request.");
                            break;
                        };
                        if let Some(tok) = maybe_tok.as_ref() {
                            used_output_tokens += tok
                                .encode(from_assistant.clone(), false)
                                .map(|e| e.len() as i32)
                                .unwrap_or(0);
                        }
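                        // Any finish reason reported by mistralrs is surfaced as Stop;
                        // exhausting our own token budget is reported as Length.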
                        let finish_reason = match &c.choices[0].finish_reason {
                            Some(_fr) => Some(FinishReason::Stop), //Some(fr.parse::<FinishReason>().unwrap_or(FinishReason::Stop)),
                            None if used_output_tokens >= max_output_tokens => {
                                tracing::debug!(used_output_tokens, max_output_tokens, "Met or exceed max_tokens. Stopping.");
                                Some(FinishReason::Length)
                            }
                            None => None,
                        };
                        //tracing::trace!("from_assistant: {from_assistant}");

                        #[allow(deprecated)]
                        let inner = async_openai::types::CreateChatCompletionStreamResponse{
                            id: c.id,
                            choices: vec![async_openai::types::ChatChoiceStream{
                                index: 0,
                                delta: async_openai::types::ChatCompletionStreamResponseDelta{
                                    //role: c.choices[0].delta.role,
                                    role: Some(async_openai::types::Role::Assistant),
                                    content: Some(from_assistant),
                                    tool_calls: None,
                                    refusal: None,
                                    function_call: None,
                                },
                                logprobs: None,
                                finish_reason,
                            }],
                            model: c.model,
                            created: c.created as u32,
                            object: c.object.clone(),
                            usage: None,
                            system_fingerprint: Some(c.system_fingerprint),
                            service_tier: None,
                        };
                        let delta = NvCreateChatCompletionStreamResponse{inner};
                        let ann = Annotated{
                            id: None,
                            data: Some(delta),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            //tracing::trace!("Finish reason: {finish_reason:?}");
                            break;
                        }
                    },
                    x => tracing::error!("Unhandled. {x:?}"),
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}

/// Convert OpenAI stop tokens to mistralrs stop tokens.
fn to_stop_tokens(t: async_openai::types::Stop) -> StopTokens {
    match t {
        async_openai::types::Stop::String(s) => StopTokens::Seqs(vec![s]),
        async_openai::types::Stop::StringArray(v) => StopTokens::Seqs(v),
    }
}

/// Convert an OpenAI logit bias map (string keys, JSON number values) to the
/// mistralrs form (token id `u32` -> bias `f32`).
/// The input is expected to look like: {"3721": -100, "17765": 100}
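/// If any key or value is malformed, the whole map is discarded and an empty map is returned.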
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let mut out = HashMap::new();
    for (key, value) in &lb {
        let token_id: u32 = match key.parse() {
            Ok(t) => t,
            Err(err) => {
                tracing::warn!(
                    "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                );
                return HashMap::new();
            }
        };
        let Some(bias) = value.as_f64() else {
            tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
            return HashMap::new();
        };
        out.insert(token_id, bias as f32);
    }
    out
}