lib.rs 23.8 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::{num::NonZero, sync::Arc};
6

Paul Hendricks's avatar
Paul Hendricks committed
7
use async_openai::types::FinishReason;
8
9
10
11
12
use async_stream::stream;
use async_trait::async_trait;
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
13
    AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
14
    GGUFLoaderBuilder, GGUFSpecificConfig, IsqType, MemoryGpuConfig, MistralRs, MistralRsBuilder,
15
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
16
17
    PagedCacheType, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig,
    StopTokens, TokenSource, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
18
19
20
};
use tokio::sync::mpsc::channel;

Neelay Shah's avatar
Neelay Shah committed
21
22
23
24
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
25

26
27
use dynamo_llm::protocols::openai::{
    chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
28
    completions::{prompt_to_string, NvCreateCompletionRequest, NvCreateCompletionResponse},
29
    embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
30
};
31
32

use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
33
use dynamo_llm::local_model::LocalModel;
34

35
36
37
38
39
40
41
42
43
/// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this.
/// I would call this the batch size but apparently that's something else.
///
/// Passed as `max_num_seqs` to `SchedulerConfig::PagedAttentionMeta`; it is only read
/// when the paged attention scheduler is selected (see `EXP_ENABLE_PAGED_ATTENTION`).
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;

/// Experimental: Switch this to true to enable paged attention on CUDA devices.
/// Under load (dynamo-run batch mode) paged attention sometimes returns an immediate
/// finish_reason=stop and no tokens for one of the requests.
///
/// The flag only has an effect in builds with the `cuda` feature: it is always combined
/// with `cfg!(feature = "cuda")`, and together they decide both whether a
/// `PagedAttentionConfig` is created and which scheduler configuration is used.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;
44

45
46
47
48
49
/// Initial message we send to mistral.rs to warm it up. We may not need this.
///
/// It is sent once, as the content of a single `user` chat message, while the engine is
/// being constructed; the reply is only logged.
const WARMUP_MESSAGE: &str = "This is a test message. Respond only with 'OK'.";

/// Builds a mistral.rs engine for `model` and returns it as a shared [`StreamingEngine`].
///
/// The concrete `MistralRsEngine` is wrapped in an [`EngineDispatcher`] before being
/// type-erased behind `Arc<dyn StreamingEngine>`.
///
/// # Errors
/// Propagates any error returned by `MistralRsEngine::new` (model loading / setup).
pub async fn make_engine(model: &LocalModel) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
    let engine = MistralRsEngine::new(model).await?;
    let engine: Arc<dyn StreamingEngine> = Arc::new(EngineDispatcher::new(engine));
    Ok(engine)
}

/// Picks the compute device the model is loaded onto.
///
/// * With the `metal` feature enabled: Metal device 0.
/// * Otherwise: whatever `Device::cuda_if_available(0)` yields (CUDA device 0 when
///   compiled with CUDA support, CPU otherwise).
///
/// # Errors
/// Returns the device-creation error from the underlying `Device` constructor.
fn best_device() -> pipeline_error::Result<Device> {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = Device::cuda_if_available(0)?;
    Ok(device)
}

struct MistralRsEngine {
    mistralrs: Arc<MistralRs>,
68
    context_length: usize,
69
    display_name: String,
70
71
72
}

impl MistralRsEngine {
73
74
75
76
77
78
79
80
    async fn new(model: &LocalModel) -> pipeline_error::Result<Self> {
        let model_path = model.path();
        // Name some None's for clarity
        let chat_template = None;
        let tokenizer_json = None;
        let no_kv_cache = false;
        let jinja_explicit = None;
        let display_name = model.display_name();
81
82
83
84
85
86
87
88
        let loader = if model_path.is_file() {
            // Load from a GGUF
            let Some(model_filename) = model_path.file_name() else {
                pipeline_error::bail!("Missing filename in model path");
            };
            let Some(model_dir) = model_path.parent() else {
                pipeline_error::bail!("Invalid model path");
            };
89

90
            GGUFLoaderBuilder::new(
91
                chat_template,
92
93
94
95
96
97
98
                None,
                model_dir.display().to_string(),
                vec![model_filename.to_string_lossy().into_owned()],
                GGUFSpecificConfig {
                    prompt_chunksize: None,
                    topology: None,
                },
99
100
                no_kv_cache,
                jinja_explicit,
101
102
            )
            .build()
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
        } else if is_vision_model(display_name) {
            let vlt = if is_gemma3(display_name) {
                VisionLoaderType::Gemma3
            } else if is_llama4(display_name) {
                VisionLoaderType::Llama4
            } else {
                panic!("Unsupported vision model {display_name}");
            };
            VisionLoaderBuilder::new(
                VisionSpecificConfig::default(),
                chat_template,
                tokenizer_json,
                Some(model_path.display().to_string()),
                jinja_explicit,
            )
118
            .build(Some(vlt))
119
120
121
        } else {
            // Load from a HF repo dir
            NormalLoaderBuilder::new(
122
123
124
                NormalSpecificConfig::default(),
                chat_template,
                tokenizer_json,
125
                Some(model_path.display().to_string()),
126
127
                no_kv_cache,
                jinja_explicit,
128
129
130
            )
            .build(None)?
        };
131

132
        let mut max_seq_len = model.card().context_length as usize;
133
134
135
136
        if max_seq_len == 0 {
            tracing::info!("context_length is 0. Probably error reading from model.");
            max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
        }
137

138
        // Paged attention requires cuda
139
        let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
140
            Some(PagedAttentionConfig::new(
141
                None, // Block size, default 32
142
                4096, // CPU memory in MiB
143
                MemoryGpuConfig::ContextSize(max_seq_len),
144
                PagedCacheType::Auto,
145
146
147
148
            )?)
        } else {
            None
        };
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163

        let device_map_params = if is_vision_model(model.display_name()) {
            AutoDeviceMapParams::Vision {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
                max_image_shape: (0, 0),
                max_num_images: 0,
            }
        } else {
            AutoDeviceMapParams::Text {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
            }
        };

164
165
166
        // Load, into a Pipeline
        let pipeline = loader.load_model_from_hf(
            None,
167
            TokenSource::None, // The model was already downloaded
168
169
170
            &ModelDType::Auto,
            &best_device()?,
            false,
171
172
173
174
175
176
            DeviceMapSetting::Auto(device_map_params),
            if is_llama4(display_name) {
                Some(IsqType::Q4K)
            } else {
                None
            },
177
178
            paged_attention_config,
        )?;
179
        let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
180
181
182
183
184
185
186
187
            tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
            let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
                Some(conf) => conf.clone(),
                None => {
                    anyhow::bail!("Failed loading model config");
                }
            };
            SchedulerConfig::PagedAttentionMeta {
188
                max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
189
190
191
192
193
                config,
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                // Safety: unwrap trivially safe here
194
                method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
195
196
197
            }
        };
        // Create the MistralRs, which is a runner
198
199
200
201
202
203
204
205
206
        let throughput_logging = false;
        let search_embedding_model = None;
        let builder = MistralRsBuilder::new(
            pipeline.clone(),
            scheduler,
            throughput_logging,
            search_embedding_model,
        )
        .with_prefix_cache_n(16);
207
        let engine = MistralRsEngine {
208
            mistralrs: builder.build().await,
209
            context_length: max_seq_len,
210
            display_name: display_name.to_string(),
211
        };
212

213
214
        // skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
        let _ = engine.mistralrs.next_request_id();
215
216
217
218

        // Perform warmup request
        let (tx, mut rx) = channel(1);
        let request_id = engine.mistralrs.next_request_id();
219
        let warmup_request = Request::Normal(Box::new(NormalRequest {
220
            id: request_id,
221
            model_id: Some(display_name.to_string()),
222
223
224
225
226
227
228
229
230
231
            messages: RequestMessage::Chat {
                messages: vec![IndexMap::from([
                    ("role".to_string(), Either::Left("user".to_string())),
                    (
                        "content".to_string(),
                        Either::Left(WARMUP_MESSAGE.to_string()),
                    ),
                ])],
                enable_thinking: Some(false),
            },
232
233
234
235
236
237
238
239
240
241
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
242
            web_search_options: None,
243
        }));
244
245

        // Send warmup request and consume response
246
        if let Ok(sender) = engine.mistralrs.get_sender(None) {
247
248
249
250
251
252
253
254
255
256
257
258
259
260
            if let Ok(()) = sender.send(warmup_request).await {
                if let Some(response) = rx.recv().await {
                    match response.as_result() {
                        Ok(r) => {
                            tracing::debug!(request_id, "Warmup response: {r:?}");
                        }
                        Err(err) => {
                            tracing::error!(request_id, %err, "Failed converting response to result.");
                        }
                    }
                }
            }
        }

261
        Ok(engine)
262
263
264
265
266
267
    }
}

#[async_trait]
impl
    AsyncEngine<
268
        SingleIn<NvCreateChatCompletionRequest>,
269
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
270
271
272
273
274
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
275
        request: SingleIn<NvCreateChatCompletionRequest>,
276
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
277
278
279
280
281
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);

        let mut messages = vec![];
Paul Hendricks's avatar
Paul Hendricks committed
282
283
284
285
        for m in request.inner.messages {
            let async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
                continue;
            };
286
287
288
289
            let async_openai::types::ChatCompletionRequestUserMessageContent::Text(content) =
                inner_m.content
            else {
                anyhow::bail!("Only Text type chat completion supported");
290
291
            };
            let r = IndexMap::from([
Paul Hendricks's avatar
Paul Hendricks committed
292
                ("role".to_string(), Either::Left("user".to_string())),
293
294
295
296
297
298
299
300
                ("content".to_string(), Either::Left(content)),
            ]);
            messages.push(r);
        }
        if messages.is_empty() {
            anyhow::bail!("Empty request");
        }

301
        let det = SamplingParams::deterministic();
302
303
        // allow deprecated because max_tokens
        #[allow(deprecated)]
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .top_logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
            max_len: {
                let requested_max_tokens = request
                    .inner
                    .max_completion_tokens
                    .or(request.inner.max_tokens)
                    .map(|m| m as usize);

                // Ensure max_len doesn't exceed context length
                match requested_max_tokens {
                    Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
                    None => det
                        .max_len
                        .map(|len| std::cmp::min(len, self.context_length)),
                }
            },
334
335
336
337
338
339
340
341
342
343
344
            logits_bias: request
                .inner
                .logit_bias
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };
345
        let request_id = self.mistralrs.next_request_id();
346
        let mistralrs_request = Request::Normal(Box::new(NormalRequest {
347
            id: request_id,
348
            model_id: Some(self.display_name.clone()),
349
350
351
352
            messages: RequestMessage::Chat {
                messages,
                enable_thinking: None,
            },
353
            sampling_params,
354
            response: tx,
355
            return_logprobs: request.inner.logprobs.unwrap_or_default(),
356
357
358
359
360
361
362
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
363
            web_search_options: None,
364
        }));
365

366
367
368
369
        self.mistralrs
            .get_sender(None)?
            .send(mistralrs_request)
            .await?;
370
371
372
373
374
375

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
376
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
377
378
379
380
381
                        break;
                    }
                };
                match response {
                    ResponseOk::Chunk(c) => {
382
                        let Some(from_assistant) = c.choices[0].delta.content.clone() else {
383
                            tracing::warn!(request_id, "No content from mistralrs. Abandoning request.");
384
385
                            break;
                        };
386
387
388
389
390
                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
Paul Hendricks's avatar
Paul Hendricks committed
391
                                Some(FinishReason::Length)
392
                            }
393
                            Some(s) => {
394
                                tracing::warn!(request_id, stop_reason = s, "Unknow stop reason");
395
396
                                Some(FinishReason::Stop)
                            }
397
398
399
400
                            None => None,
                        };
                        //tracing::trace!("from_assistant: {from_assistant}");

Paul Hendricks's avatar
Paul Hendricks committed
401
402
                        #[allow(deprecated)]
                        let inner = async_openai::types::CreateChatCompletionStreamResponse{
403
                            id: c.id,
Paul Hendricks's avatar
Paul Hendricks committed
404
                            choices: vec![async_openai::types::ChatChoiceStream{
405
                                index: 0,
Paul Hendricks's avatar
Paul Hendricks committed
406
                                delta: async_openai::types::ChatCompletionStreamResponseDelta{
407
                                    //role: c.choices[0].delta.role,
Paul Hendricks's avatar
Paul Hendricks committed
408
                                    role: Some(async_openai::types::Role::Assistant),
409
410
                                    content: Some(from_assistant),
                                    tool_calls: None,
Paul Hendricks's avatar
Paul Hendricks committed
411
412
                                    refusal: None,
                                    function_call: None,
413
414
415
416
417
                                },
                                logprobs: None,
                                finish_reason,
                            }],
                            model: c.model,
Paul Hendricks's avatar
Paul Hendricks committed
418
                            created: c.created as u32,
419
420
421
422
423
                            object: c.object.clone(),
                            usage: None,
                            system_fingerprint: Some(c.system_fingerprint),
                            service_tier: None,
                        };
424
                        let delta = NvCreateChatCompletionStreamResponse{inner};
425
426
427
428
429
430
431
432
433
                        let ann = Annotated{
                            id: None,
                            data: Some(delta),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
434
                            //tracing::trace!(request_id, "Finish reason: {finish_reason:?}");
435
436
437
                            break;
                        }
                    },
438
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
439
440
441
442
443
444
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475

/// Converts OpenAI stop sequences into mistralrs stop tokens.
///
/// A single stop string becomes a one-element sequence list; an array of stop strings
/// is passed through unchanged.
fn to_stop_tokens(t: async_openai::types::Stop) -> StopTokens {
    use async_openai::types::Stop;

    let sequences = match t {
        Stop::String(single) => vec![single],
        Stop::StringArray(many) => many,
    };
    StopTokens::Seqs(sequences)
}

/// Converts an OpenAI `logit_bias` map (string keys, JSON values) into the mistralrs
/// form (`u32` token id to `f32` bias).
///
/// The input presumably looks like this: `{"3721": -100, "17765": 100}`.
///
/// Conversion is all-or-nothing: if any key does not parse as a `u32` or any value is
/// not readable as a float, a warning is logged and an empty map is returned.
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let mut out = HashMap::new();
    for (key, value) in lb.iter() {
        // The key is checked first, so a bad key is reported even when the value is bad too.
        match (key.parse::<u32>(), value.as_f64()) {
            (Err(err), _) => {
                tracing::warn!(
                    "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                );
                return HashMap::new();
            }
            (Ok(_), None) => {
                tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
                return HashMap::new();
            }
            (Ok(token_id), Some(bias)) => {
                out.insert(token_id, bias as f32);
            }
        }
    }
    out
}
476
477

#[async_trait]
478
479
480
481
482
483
impl
    AsyncEngine<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<NvCreateCompletionResponse>>,
        Error,
    > for MistralRsEngine
484
485
486
{
    async fn generate(
        &self,
487
        request: SingleIn<NvCreateCompletionRequest>,
488
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);
        let response_generator = request.response_generator();

        let messages = RequestMessage::Completion {
            text: prompt_to_string(&request.inner.prompt),
            echo_prompt: false,
            best_of: Some(1),
        };
        let det = SamplingParams::deterministic();
        // allow deprecated because max_tokens
        #[allow(deprecated)]
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request
                .inner
                .stop
                .clone()
                .map(to_stop_tokens)
                .or(det.stop_toks),
522
523
524
525
526
527
528
529
530
531
532
            max_len: {
                let requested_max_tokens = request.inner.max_tokens.map(|m| m as usize);

                // Ensure max_len doesn't exceed context length
                match requested_max_tokens {
                    Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
                    None => det
                        .max_len
                        .map(|len| std::cmp::min(len, self.context_length)),
                }
            },
533
534
535
536
537
538
539
540
541
542
543
544
545
546
            logits_bias: request
                .inner
                .logit_bias
                .clone()
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };

        let request_id = self.mistralrs.next_request_id();
547
        let mistralrs_request = Request::Normal(Box::new(NormalRequest {
548
            id: request_id,
549
            model_id: Some(self.display_name.clone()),
550
551
552
553
554
555
556
557
558
559
560
            messages,
            sampling_params,
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
561
            web_search_options: None,
562
        }));
563

564
565
566
567
        self.mistralrs
            .get_sender(None)?
            .send(mistralrs_request)
            .await?;
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
                        break;
                    }
                };
                match response {
                    ResponseOk::CompletionChunk(c) => {
                        let from_assistant = c.choices[0].text.clone();

                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
                                Some(FinishReason::Length)
                            }
                            Some(s) => {
                                tracing::warn!(request_id, stop_reason = s, "Unknow stop reason");
                                Some(FinishReason::Stop)
                            }
                            None => None,
                        };
                        #[allow(deprecated)]
                        let inner = response_generator.create_choice(0, Some(from_assistant), None);
                        let ann = Annotated{
                            id: None,
                            data: Some(inner),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            break;
                        }
                    },
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
616
617
618
619
620
621
622
623
624
625
626
627

/// Returns `true` when the model name `s` belongs to one of the vision-capable
/// families this engine recognises (Gemma 3 or Llama 4).
fn is_vision_model(s: &str) -> bool {
    if is_gemma3(s) {
        return true;
    }
    is_llama4(s)
}

/// Returns `true` when the model name `s` contains `gemma-3`, ignoring case.
fn is_gemma3(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("gemma-3")
}

/// Returns `true` when the model name `s` contains `llama-4`, ignoring case.
fn is_llama4(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("llama-4")
}
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643

#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
        _request: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        unimplemented!()
    }
}