lib.rs 24.1 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::{num::NonZero, sync::Arc};
6
7
8

use async_stream::stream;
use async_trait::async_trait;
9
use dynamo_async_openai::types::FinishReason;
10
11
12
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
13
    AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
14
    GGUFLoaderBuilder, GGUFSpecificConfig, IsqType, MemoryGpuConfig, MistralRs, MistralRsBuilder,
15
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
16
17
    PagedCacheType, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig,
    StopTokens, TokenSource, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
18
19
20
};
use tokio::sync::mpsc::channel;

Neelay Shah's avatar
Neelay Shah committed
21
22
23
24
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
25

26
27
use dynamo_llm::protocols::openai::{
    chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
28
    completions::{NvCreateCompletionRequest, NvCreateCompletionResponse, prompt_to_string},
29
    embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
30
};
31
32

use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
33
use dynamo_llm::local_model::LocalModel;
34

35
36
37
38
39
40
41
42
43
/// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this.
/// I would call this the batch size but apparently that's something else.
///
/// Passed as `max_num_seqs` to `SchedulerConfig::PagedAttentionMeta` in
/// `MistralRsEngine::new`; it has no effect when the default scheduler is used.
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;

/// Experimental: Switch this to true to enable paged attention on CUDA devices.
/// Under load (dynamo-run batch mode) paged attention sometimes returns an immediate
/// finish_reason=stop and no tokens for one of the requests.
///
/// Only takes effect together with the `cuda` cargo feature: `MistralRsEngine::new`
/// checks both before creating a `PagedAttentionConfig` and before selecting the
/// `PagedAttentionMeta` scheduler.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;
44

45
46
47
48
49
/// Initial message we send to mistral.rs to warm it up. We may not need this.
///
/// Sent once, as the content of a single `user` chat message, at the end of
/// `MistralRsEngine::new`.
const WARMUP_MESSAGE: &str = "This is a test message. Respond only with 'OK'.";

/// Builds a mistral.rs backed engine for `model` and returns it as a shared
/// [`StreamingEngine`].
///
/// The concrete [`MistralRsEngine`] is wrapped in an [`EngineDispatcher`] before being
/// type-erased behind `Arc<dyn StreamingEngine>`.
///
/// # Errors
///
/// Propagates any error returned by `MistralRsEngine::new`.
pub async fn make_engine(model: &LocalModel) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
    let inner = MistralRsEngine::new(model).await?;
    let engine: Arc<dyn StreamingEngine> = Arc::new(EngineDispatcher::new(inner));
    Ok(engine)
}

/// Picks the device the model is loaded onto.
///
/// When the crate is built with the `metal` feature this is Metal device 0; otherwise
/// it is whatever `Device::cuda_if_available(0)` returns. Errors from device creation
/// are propagated to the caller.
fn best_device() -> pipeline_error::Result<Device> {
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;

    #[cfg(not(feature = "metal"))]
    let device = Device::cuda_if_available(0)?;

    Ok(device)
}

struct MistralRsEngine {
    mistralrs: Arc<MistralRs>,
68
    context_length: usize,
69
    display_name: String,
70
71
72
}

impl MistralRsEngine {
73
74
75
76
77
78
79
80
    async fn new(model: &LocalModel) -> pipeline_error::Result<Self> {
        let model_path = model.path();
        // Name some None's for clarity
        let chat_template = None;
        let tokenizer_json = None;
        let no_kv_cache = false;
        let jinja_explicit = None;
        let display_name = model.display_name();
81
82
83
84
85
86
87
88
        let loader = if model_path.is_file() {
            // Load from a GGUF
            let Some(model_filename) = model_path.file_name() else {
                pipeline_error::bail!("Missing filename in model path");
            };
            let Some(model_dir) = model_path.parent() else {
                pipeline_error::bail!("Invalid model path");
            };
89

90
            GGUFLoaderBuilder::new(
91
                chat_template,
92
93
94
                None,
                model_dir.display().to_string(),
                vec![model_filename.to_string_lossy().into_owned()],
95
                GGUFSpecificConfig::default(),
96
97
                no_kv_cache,
                jinja_explicit,
98
99
            )
            .build()
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
        } else if is_vision_model(display_name) {
            let vlt = if is_gemma3(display_name) {
                VisionLoaderType::Gemma3
            } else if is_llama4(display_name) {
                VisionLoaderType::Llama4
            } else {
                panic!("Unsupported vision model {display_name}");
            };
            VisionLoaderBuilder::new(
                VisionSpecificConfig::default(),
                chat_template,
                tokenizer_json,
                Some(model_path.display().to_string()),
                jinja_explicit,
            )
115
            .build(Some(vlt))
116
117
118
        } else {
            // Load from a HF repo dir
            NormalLoaderBuilder::new(
119
120
121
                NormalSpecificConfig::default(),
                chat_template,
                tokenizer_json,
122
                Some(model_path.display().to_string()),
123
124
                no_kv_cache,
                jinja_explicit,
125
126
127
            )
            .build(None)?
        };
128

129
        let mut max_seq_len = model.card().context_length as usize;
130
131
132
133
        if max_seq_len == 0 {
            tracing::info!("context_length is 0. Probably error reading from model.");
            max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
        }
134

135
        // Paged attention requires cuda
136
        let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
137
            Some(PagedAttentionConfig::new(
138
                None, // Block size, default 32
139
                4096, // CPU memory in MiB
140
                MemoryGpuConfig::ContextSize(max_seq_len),
141
                PagedCacheType::Auto,
142
143
144
145
            )?)
        } else {
            None
        };
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

        let device_map_params = if is_vision_model(model.display_name()) {
            AutoDeviceMapParams::Vision {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
                max_image_shape: (0, 0),
                max_num_images: 0,
            }
        } else {
            AutoDeviceMapParams::Text {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
            }
        };

161
162
163
        // Load, into a Pipeline
        let pipeline = loader.load_model_from_hf(
            None,
164
            TokenSource::None, // The model was already downloaded
165
166
167
            &ModelDType::Auto,
            &best_device()?,
            false,
168
169
170
171
172
173
            DeviceMapSetting::Auto(device_map_params),
            if is_llama4(display_name) {
                Some(IsqType::Q4K)
            } else {
                None
            },
174
175
            paged_attention_config,
        )?;
176
        let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
177
178
179
180
181
182
183
184
            tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
            let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
                Some(conf) => conf.clone(),
                None => {
                    anyhow::bail!("Failed loading model config");
                }
            };
            SchedulerConfig::PagedAttentionMeta {
185
                max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
186
187
188
189
190
                config,
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                // Safety: unwrap trivially safe here
191
                method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
192
193
194
            }
        };
        // Create the MistralRs, which is a runner
195
196
197
198
199
200
201
202
203
        let throughput_logging = false;
        let search_embedding_model = None;
        let builder = MistralRsBuilder::new(
            pipeline.clone(),
            scheduler,
            throughput_logging,
            search_embedding_model,
        )
        .with_prefix_cache_n(16);
204
        let engine = MistralRsEngine {
205
            mistralrs: builder.build().await,
206
            context_length: max_seq_len,
207
            display_name: display_name.to_string(),
208
        };
209

210
211
        // skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
        let _ = engine.mistralrs.next_request_id();
212
213
214

        // Perform warmup request
        let (tx, mut rx) = channel(1);
215
        let mistralrs_request_id = engine.mistralrs.next_request_id();
216
        let warmup_request = Request::Normal(Box::new(NormalRequest {
217
            id: mistralrs_request_id,
218
            model_id: Some(display_name.to_string()),
219
220
221
222
223
224
225
226
227
228
            messages: RequestMessage::Chat {
                messages: vec![IndexMap::from([
                    ("role".to_string(), Either::Left("user".to_string())),
                    (
                        "content".to_string(),
                        Either::Left(WARMUP_MESSAGE.to_string()),
                    ),
                ])],
                enable_thinking: Some(false),
            },
229
230
231
232
233
234
235
236
237
238
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
239
            web_search_options: None,
240
        }));
241
242

        // Send warmup request and consume response
243
244
245
246
247
248
        if let Ok(sender) = engine.mistralrs.get_sender(None)
            && let Ok(()) = sender.send(warmup_request).await
            && let Some(response) = rx.recv().await
        {
            match response.as_result() {
                Ok(r) => {
249
                    tracing::debug!(mistralrs_request_id, "Warmup response: {r:?}");
250
251
                }
                Err(err) => {
252
                    tracing::error!(mistralrs_request_id, %err, "Failed converting response to result.");
253
254
255
256
                }
            }
        }

257
        Ok(engine)
258
259
260
261
262
263
    }
}

#[async_trait]
impl
    AsyncEngine<
264
        SingleIn<NvCreateChatCompletionRequest>,
265
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
266
267
268
269
270
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
271
        request: SingleIn<NvCreateChatCompletionRequest>,
272
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
273
274
        let (request, context) = request.transfer(());
        let ctx = context.context();
275
        let request_id = ctx.id().to_string();
276
277
278
        let (tx, mut rx) = channel(10_000);

        let mut messages = vec![];
Paul Hendricks's avatar
Paul Hendricks committed
279
        for m in request.inner.messages {
280
            let dynamo_async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
Paul Hendricks's avatar
Paul Hendricks committed
281
282
                continue;
            };
283
            let dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(content) =
284
285
286
                inner_m.content
            else {
                anyhow::bail!("Only Text type chat completion supported");
287
288
            };
            let r = IndexMap::from([
Paul Hendricks's avatar
Paul Hendricks committed
289
                ("role".to_string(), Either::Left("user".to_string())),
290
291
292
293
294
295
296
297
                ("content".to_string(), Either::Left(content)),
            ]);
            messages.push(r);
        }
        if messages.is_empty() {
            anyhow::bail!("Empty request");
        }

298
        let det = SamplingParams::deterministic();
299
300
        // allow deprecated because max_tokens
        #[allow(deprecated)]
301
302
303
304
305
306
307
308
309
310
311
312
313
314
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .top_logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
Graham King's avatar
Graham King committed
315
            repetition_penalty: det.repetition_penalty,
316
            stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
            max_len: {
                let requested_max_tokens = request
                    .inner
                    .max_completion_tokens
                    .or(request.inner.max_tokens)
                    .map(|m| m as usize);

                // Ensure max_len doesn't exceed context length
                match requested_max_tokens {
                    Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
                    None => det
                        .max_len
                        .map(|len| std::cmp::min(len, self.context_length)),
                }
            },
332
333
334
335
336
337
338
339
340
341
342
            logits_bias: request
                .inner
                .logit_bias
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };
343
        let mistralrs_request_id = self.mistralrs.next_request_id();
344
        let mistralrs_request = Request::Normal(Box::new(NormalRequest {
345
            id: mistralrs_request_id,
346
            model_id: Some(self.display_name.clone()),
347
348
349
350
            messages: RequestMessage::Chat {
                messages,
                enable_thinking: None,
            },
351
            sampling_params,
352
            response: tx,
353
            return_logprobs: request.inner.logprobs.unwrap_or_default(),
354
355
356
357
358
359
360
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
361
            web_search_options: None,
362
        }));
363

364
365
366
367
        self.mistralrs
            .get_sender(None)?
            .send(mistralrs_request)
            .await?;
368
369
370
371
372
373

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
374
                        tracing::error!(mistralrs_request_id, %err, "Failed converting mistralrs channel response to result.");
375
376
377
378
379
                        break;
                    }
                };
                match response {
                    ResponseOk::Chunk(c) => {
380
                        let Some(from_assistant) = c.choices[0].delta.content.clone() else {
381
                            tracing::warn!(mistralrs_request_id, "No content from mistralrs. Abandoning request.");
382
383
                            break;
                        };
384
385
386
387
388
                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
Paul Hendricks's avatar
Paul Hendricks committed
389
                                Some(FinishReason::Length)
390
                            }
391
                            Some(s) => {
392
                                tracing::warn!(mistralrs_request_id, stop_reason = s, "Unknow stop reason");
393
394
                                Some(FinishReason::Stop)
                            }
395
396
397
398
                            None => None,
                        };
                        //tracing::trace!("from_assistant: {from_assistant}");

Paul Hendricks's avatar
Paul Hendricks committed
399
                        #[allow(deprecated)]
400
                        let delta = NvCreateChatCompletionStreamResponse {
401
                            id: format!("chatcmpl-{request_id}"),
402
                            choices: vec![dynamo_async_openai::types::ChatChoiceStream{
403
                                index: 0,
404
                                delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta{
405
                                    //role: c.choices[0].delta.role,
406
                                    role: Some(dynamo_async_openai::types::Role::Assistant),
407
408
                                    content: Some(from_assistant),
                                    tool_calls: None,
Paul Hendricks's avatar
Paul Hendricks committed
409
410
                                    refusal: None,
                                    function_call: None,
411
                                    reasoning_content: None,
412
413
414
415
416
                                },
                                logprobs: None,
                                finish_reason,
                            }],
                            model: c.model,
Paul Hendricks's avatar
Paul Hendricks committed
417
                            created: c.created as u32,
418
419
420
421
                            object: c.object.clone(),
                            usage: None,
                            system_fingerprint: Some(c.system_fingerprint),
                            service_tier: None,
422
                            nvext: None,
423
424
425
426
427
428
429
430
431
432
                        };
                        let ann = Annotated{
                            id: None,
                            data: Some(delta),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
433
                            //tracing::trace!(mistralrs_request_id, "Finish reason: {finish_reason:?}");
434
435
436
                            break;
                        }
                    },
437
                    x => tracing::error!(mistralrs_request_id, "Unhandled. {x:?}"),
438
439
440
441
442
443
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
444
445

/// openai stop tokens to mistralrs stop tokens
446
fn to_stop_tokens(t: dynamo_async_openai::types::Stop) -> StopTokens {
447
    match t {
448
449
        dynamo_async_openai::types::Stop::String(s) => StopTokens::Seqs(vec![s]),
        dynamo_async_openai::types::Stop::StringArray(v) => StopTokens::Seqs(v),
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
    }
}

/// Converts an OpenAI logit bias map (string keys, JSON values) into the
/// `token id -> bias` map mistral.rs expects.
///
/// The input is expected to look like `{"3721": -100, "17765": 100}`. If any key does
/// not parse as a `u32`, or any value is not a number, a warning is logged and an
/// empty map is returned, discarding the entries converted so far.
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let mut converted: HashMap<u32, f32> = HashMap::new();
    for (key, value) in lb.iter() {
        let token_id = match key.parse::<u32>() {
            Err(err) => {
                tracing::warn!(
                    "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                );
                return HashMap::new();
            }
            Ok(id) => id,
        };
        match value.as_f64() {
            Some(bias) => {
                converted.insert(token_id, bias as f32);
            }
            None => {
                tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
                return HashMap::new();
            }
        }
    }
    converted
}
475
476

#[async_trait]
477
478
479
480
481
482
impl
    AsyncEngine<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<NvCreateCompletionResponse>>,
        Error,
    > for MistralRsEngine
483
484
485
{
    async fn generate(
        &self,
486
        request: SingleIn<NvCreateCompletionRequest>,
487
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
488
489
490
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);
491
        let response_generator = request.response_generator(ctx.id().to_string());
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514

        let messages = RequestMessage::Completion {
            text: prompt_to_string(&request.inner.prompt),
            echo_prompt: false,
            best_of: Some(1),
        };
        let det = SamplingParams::deterministic();
        // allow deprecated because max_tokens
        #[allow(deprecated)]
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
Graham King's avatar
Graham King committed
515
            repetition_penalty: det.repetition_penalty,
516
517
518
519
520
521
            stop_toks: request
                .inner
                .stop
                .clone()
                .map(to_stop_tokens)
                .or(det.stop_toks),
522
523
524
525
526
527
528
529
530
531
532
            max_len: {
                let requested_max_tokens = request.inner.max_tokens.map(|m| m as usize);

                // Ensure max_len doesn't exceed context length
                match requested_max_tokens {
                    Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
                    None => det
                        .max_len
                        .map(|len| std::cmp::min(len, self.context_length)),
                }
            },
533
534
535
536
537
538
539
540
541
542
543
544
545
            logits_bias: request
                .inner
                .logit_bias
                .clone()
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };

546
        let mistralrs_request_id = self.mistralrs.next_request_id();
547
        let mistralrs_request = Request::Normal(Box::new(NormalRequest {
548
            id: mistralrs_request_id,
549
            model_id: Some(self.display_name.clone()),
550
551
552
553
554
555
556
557
558
559
560
            messages,
            sampling_params,
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
561
            web_search_options: None,
562
        }));
563

564
565
566
567
        self.mistralrs
            .get_sender(None)?
            .send(mistralrs_request)
            .await?;
568
569
570
571
572
573

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
574
                        tracing::error!(mistralrs_request_id, %err, "Failed converting mistralrs channel response to result.");
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
                        break;
                    }
                };
                match response {
                    ResponseOk::CompletionChunk(c) => {
                        let from_assistant = c.choices[0].text.clone();

                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
                                Some(FinishReason::Length)
                            }
                            Some(s) => {
590
                                tracing::warn!(mistralrs_request_id, stop_reason = s, "Unknow stop reason");
591
592
593
594
595
                                Some(FinishReason::Stop)
                            }
                            None => None,
                        };
                        #[allow(deprecated)]
Greg Clark's avatar
Greg Clark committed
596
                        let inner = response_generator.create_choice(0, Some(from_assistant), None, None);
597
598
599
600
601
602
603
604
605
606
607
608
                        let ann = Annotated{
                            id: None,
                            data: Some(inner),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            break;
                        }
                    },
609
                    x => tracing::error!(mistralrs_request_id, "Unhandled. {x:?}"),
610
611
612
613
614
615
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
616
617
618
619
620
621
622
623
624
625
626
627

/// Reports whether the model name `s` belongs to one of the vision model families
/// this engine knows how to load (Gemma 3 or Llama 4).
fn is_vision_model(s: &str) -> bool {
    if is_gemma3(s) {
        return true;
    }
    is_llama4(s)
}

/// Reports whether the model name `s` contains "gemma-3", ignoring case.
fn is_gemma3(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("gemma-3")
}

/// Reports whether the model name `s` contains "llama-4", ignoring case.
fn is_llama4(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("llama-4")
}
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643

#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
        _request: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        unimplemented!()
    }
}