lib.rs 23.8 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use std::collections::HashMap;
5
use std::{num::NonZero, sync::Arc};
6
7
8

use async_stream::stream;
use async_trait::async_trait;
9
use dynamo_async_openai::types::FinishReason;
10
11
12
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
13
    AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
14
    GGUFLoaderBuilder, GGUFSpecificConfig, IsqType, MemoryGpuConfig, MistralRs, MistralRsBuilder,
15
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
16
17
    PagedCacheType, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig,
    StopTokens, TokenSource, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
18
19
20
};
use tokio::sync::mpsc::channel;

Neelay Shah's avatar
Neelay Shah committed
21
22
23
24
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
25

26
27
use dynamo_llm::protocols::openai::{
    chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
28
    completions::{prompt_to_string, NvCreateCompletionRequest, NvCreateCompletionResponse},
29
    embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
30
};
31
32

use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
33
use dynamo_llm::local_model::LocalModel;
34

35
36
37
38
39
40
41
42
43
/// Upper bound on concurrently running requests for the paged-attention scheduler.
///
/// mistral.rs actually runs one fewer than this. This is distinct from what mistral.rs
/// calls the "batch size".
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;

/// Experimental switch: set to `true` to enable paged attention on CUDA builds.
///
/// Known issue: under load (dynamo-run batch mode), paged attention sometimes ends one of
/// the requests immediately with `finish_reason=stop` and no tokens.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;

/// Prompt of the single warmup request sent to mistral.rs after loading.
/// It may turn out to be unnecessary.
const WARMUP_MESSAGE: &str = "This is a test message. Respond only with 'OK'.";

/// Loads `model` into a mistral.rs engine and exposes it as a dynamo [`StreamingEngine`].
///
/// The engine is wrapped in an [`EngineDispatcher`], which routes chat, completion and
/// embedding requests to the matching `AsyncEngine` implementation.
///
/// # Errors
/// Propagates any failure from loading or warming up the model.
pub async fn make_engine(model: &LocalModel) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
    let dispatcher = EngineDispatcher::new(MistralRsEngine::new(model).await?);
    Ok(Arc::new(dispatcher) as Arc<dyn StreamingEngine>)
}

/// Picks the device to load the model on.
///
/// With the `metal` feature this is Metal device 0. Otherwise it is CUDA device 0 when
/// available, falling back to the CPU.
fn best_device() -> pipeline_error::Result<Device> {
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = Device::cuda_if_available(0)?;
    Ok(device)
}

struct MistralRsEngine {
    mistralrs: Arc<MistralRs>,
68
    context_length: usize,
69
    display_name: String,
70
71
72
}

impl MistralRsEngine {
73
74
75
76
77
78
79
80
    async fn new(model: &LocalModel) -> pipeline_error::Result<Self> {
        let model_path = model.path();
        // Name some None's for clarity
        let chat_template = None;
        let tokenizer_json = None;
        let no_kv_cache = false;
        let jinja_explicit = None;
        let display_name = model.display_name();
81
82
83
84
85
86
87
88
        let loader = if model_path.is_file() {
            // Load from a GGUF
            let Some(model_filename) = model_path.file_name() else {
                pipeline_error::bail!("Missing filename in model path");
            };
            let Some(model_dir) = model_path.parent() else {
                pipeline_error::bail!("Invalid model path");
            };
89

90
            GGUFLoaderBuilder::new(
91
                chat_template,
92
93
94
                None,
                model_dir.display().to_string(),
                vec![model_filename.to_string_lossy().into_owned()],
95
                GGUFSpecificConfig::default(),
96
97
                no_kv_cache,
                jinja_explicit,
98
99
            )
            .build()
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
        } else if is_vision_model(display_name) {
            let vlt = if is_gemma3(display_name) {
                VisionLoaderType::Gemma3
            } else if is_llama4(display_name) {
                VisionLoaderType::Llama4
            } else {
                panic!("Unsupported vision model {display_name}");
            };
            VisionLoaderBuilder::new(
                VisionSpecificConfig::default(),
                chat_template,
                tokenizer_json,
                Some(model_path.display().to_string()),
                jinja_explicit,
            )
115
            .build(Some(vlt))
116
117
118
        } else {
            // Load from a HF repo dir
            NormalLoaderBuilder::new(
119
120
121
                NormalSpecificConfig::default(),
                chat_template,
                tokenizer_json,
122
                Some(model_path.display().to_string()),
123
124
                no_kv_cache,
                jinja_explicit,
125
126
127
            )
            .build(None)?
        };
128

129
        let mut max_seq_len = model.card().context_length as usize;
130
131
132
133
        if max_seq_len == 0 {
            tracing::info!("context_length is 0. Probably error reading from model.");
            max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
        }
134

135
        // Paged attention requires cuda
136
        let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
137
            Some(PagedAttentionConfig::new(
138
                None, // Block size, default 32
139
                4096, // CPU memory in MiB
140
                MemoryGpuConfig::ContextSize(max_seq_len),
141
                PagedCacheType::Auto,
142
143
144
145
            )?)
        } else {
            None
        };
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

        let device_map_params = if is_vision_model(model.display_name()) {
            AutoDeviceMapParams::Vision {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
                max_image_shape: (0, 0),
                max_num_images: 0,
            }
        } else {
            AutoDeviceMapParams::Text {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
            }
        };

161
162
163
        // Load, into a Pipeline
        let pipeline = loader.load_model_from_hf(
            None,
164
            TokenSource::None, // The model was already downloaded
165
166
167
            &ModelDType::Auto,
            &best_device()?,
            false,
168
169
170
171
172
173
            DeviceMapSetting::Auto(device_map_params),
            if is_llama4(display_name) {
                Some(IsqType::Q4K)
            } else {
                None
            },
174
175
            paged_attention_config,
        )?;
176
        let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
177
178
179
180
181
182
183
184
            tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
            let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
                Some(conf) => conf.clone(),
                None => {
                    anyhow::bail!("Failed loading model config");
                }
            };
            SchedulerConfig::PagedAttentionMeta {
185
                max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
186
187
188
189
190
                config,
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                // Safety: unwrap trivially safe here
191
                method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
192
193
194
            }
        };
        // Create the MistralRs, which is a runner
195
196
197
198
199
200
201
202
203
        let throughput_logging = false;
        let search_embedding_model = None;
        let builder = MistralRsBuilder::new(
            pipeline.clone(),
            scheduler,
            throughput_logging,
            search_embedding_model,
        )
        .with_prefix_cache_n(16);
204
        let engine = MistralRsEngine {
205
            mistralrs: builder.build().await,
206
            context_length: max_seq_len,
207
            display_name: display_name.to_string(),
208
        };
209

210
211
        // skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
        let _ = engine.mistralrs.next_request_id();
212
213
214
215

        // Perform warmup request
        let (tx, mut rx) = channel(1);
        let request_id = engine.mistralrs.next_request_id();
216
        let warmup_request = Request::Normal(Box::new(NormalRequest {
217
            id: request_id,
218
            model_id: Some(display_name.to_string()),
219
220
221
222
223
224
225
226
227
228
            messages: RequestMessage::Chat {
                messages: vec![IndexMap::from([
                    ("role".to_string(), Either::Left("user".to_string())),
                    (
                        "content".to_string(),
                        Either::Left(WARMUP_MESSAGE.to_string()),
                    ),
                ])],
                enable_thinking: Some(false),
            },
229
230
231
232
233
234
235
236
237
238
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
239
            web_search_options: None,
240
        }));
241
242

        // Send warmup request and consume response
243
        if let Ok(sender) = engine.mistralrs.get_sender(None) {
244
245
246
247
248
249
250
251
252
253
254
255
256
257
            if let Ok(()) = sender.send(warmup_request).await {
                if let Some(response) = rx.recv().await {
                    match response.as_result() {
                        Ok(r) => {
                            tracing::debug!(request_id, "Warmup response: {r:?}");
                        }
                        Err(err) => {
                            tracing::error!(request_id, %err, "Failed converting response to result.");
                        }
                    }
                }
            }
        }

258
        Ok(engine)
259
260
261
262
263
264
    }
}

#[async_trait]
impl
    AsyncEngine<
265
        SingleIn<NvCreateChatCompletionRequest>,
266
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
267
268
269
270
271
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
272
        request: SingleIn<NvCreateChatCompletionRequest>,
273
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
274
275
276
277
278
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);

        let mut messages = vec![];
Paul Hendricks's avatar
Paul Hendricks committed
279
        for m in request.inner.messages {
280
            let dynamo_async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
Paul Hendricks's avatar
Paul Hendricks committed
281
282
                continue;
            };
283
            let dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(content) =
284
285
286
                inner_m.content
            else {
                anyhow::bail!("Only Text type chat completion supported");
287
288
            };
            let r = IndexMap::from([
Paul Hendricks's avatar
Paul Hendricks committed
289
                ("role".to_string(), Either::Left("user".to_string())),
290
291
292
293
294
295
296
297
                ("content".to_string(), Either::Left(content)),
            ]);
            messages.push(r);
        }
        if messages.is_empty() {
            anyhow::bail!("Empty request");
        }

298
        let det = SamplingParams::deterministic();
299
300
        // allow deprecated because max_tokens
        #[allow(deprecated)]
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .top_logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
            max_len: {
                let requested_max_tokens = request
                    .inner
                    .max_completion_tokens
                    .or(request.inner.max_tokens)
                    .map(|m| m as usize);

                // Ensure max_len doesn't exceed context length
                match requested_max_tokens {
                    Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
                    None => det
                        .max_len
                        .map(|len| std::cmp::min(len, self.context_length)),
                }
            },
331
332
333
334
335
336
337
338
339
340
341
            logits_bias: request
                .inner
                .logit_bias
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };
342
        let request_id = self.mistralrs.next_request_id();
343
        let mistralrs_request = Request::Normal(Box::new(NormalRequest {
344
            id: request_id,
345
            model_id: Some(self.display_name.clone()),
346
347
348
349
            messages: RequestMessage::Chat {
                messages,
                enable_thinking: None,
            },
350
            sampling_params,
351
            response: tx,
352
            return_logprobs: request.inner.logprobs.unwrap_or_default(),
353
354
355
356
357
358
359
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
360
            web_search_options: None,
361
        }));
362

363
364
365
366
        self.mistralrs
            .get_sender(None)?
            .send(mistralrs_request)
            .await?;
367
368
369
370
371
372

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
373
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
374
375
376
377
378
                        break;
                    }
                };
                match response {
                    ResponseOk::Chunk(c) => {
379
                        let Some(from_assistant) = c.choices[0].delta.content.clone() else {
380
                            tracing::warn!(request_id, "No content from mistralrs. Abandoning request.");
381
382
                            break;
                        };
383
384
385
386
387
                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
Paul Hendricks's avatar
Paul Hendricks committed
388
                                Some(FinishReason::Length)
389
                            }
390
                            Some(s) => {
391
                                tracing::warn!(request_id, stop_reason = s, "Unknow stop reason");
392
393
                                Some(FinishReason::Stop)
                            }
394
395
396
397
                            None => None,
                        };
                        //tracing::trace!("from_assistant: {from_assistant}");

Paul Hendricks's avatar
Paul Hendricks committed
398
                        #[allow(deprecated)]
399
                        let delta = NvCreateChatCompletionStreamResponse {
400
                            id: c.id,
401
                            choices: vec![dynamo_async_openai::types::ChatChoiceStream{
402
                                index: 0,
403
                                delta: dynamo_async_openai::types::ChatCompletionStreamResponseDelta{
404
                                    //role: c.choices[0].delta.role,
405
                                    role: Some(dynamo_async_openai::types::Role::Assistant),
406
407
                                    content: Some(from_assistant),
                                    tool_calls: None,
Paul Hendricks's avatar
Paul Hendricks committed
408
409
                                    refusal: None,
                                    function_call: None,
410
                                    reasoning_content: None,
411
412
413
414
415
                                },
                                logprobs: None,
                                finish_reason,
                            }],
                            model: c.model,
Paul Hendricks's avatar
Paul Hendricks committed
416
                            created: c.created as u32,
417
418
419
420
421
422
423
424
425
426
427
428
429
430
                            object: c.object.clone(),
                            usage: None,
                            system_fingerprint: Some(c.system_fingerprint),
                            service_tier: None,
                        };
                        let ann = Annotated{
                            id: None,
                            data: Some(delta),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
431
                            //tracing::trace!(request_id, "Finish reason: {finish_reason:?}");
432
433
434
                            break;
                        }
                    },
435
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
436
437
438
439
440
441
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
442
443

/// openai stop tokens to mistralrs stop tokens
444
fn to_stop_tokens(t: dynamo_async_openai::types::Stop) -> StopTokens {
445
    match t {
446
447
        dynamo_async_openai::types::Stop::String(s) => StopTokens::Seqs(vec![s]),
        dynamo_async_openai::types::Stop::StringArray(v) => StopTokens::Seqs(v),
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
    }
}

/// Converts an OpenAI `logit_bias` map (token-id strings to JSON numbers) into the
/// mistral.rs form (`u32` token id to `f32` bias).
///
/// Expected input looks like `{"3721": -100, "17765": 100}`. The conversion is
/// all-or-nothing: if any key is not a `u32`, or any value is not a number, a warning is
/// logged and an empty map is returned.
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let convert_entry = |(key, value): (&String, &serde_json::Value)| -> Option<(u32, f32)> {
        let token_id = match key.parse::<u32>() {
            Ok(id) => id,
            Err(err) => {
                tracing::warn!(
                    "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                );
                return None;
            }
        };
        match value.as_f64() {
            Some(bias) => Some((token_id, bias as f32)),
            None => {
                tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
                None
            }
        }
    };

    // Collecting into Option short-circuits on the first bad entry.
    lb.iter()
        .map(convert_entry)
        .collect::<Option<HashMap<u32, f32>>>()
        .unwrap_or_default()
}
473
474

#[async_trait]
475
476
477
478
479
480
impl
    AsyncEngine<
        SingleIn<NvCreateCompletionRequest>,
        ManyOut<Annotated<NvCreateCompletionResponse>>,
        Error,
    > for MistralRsEngine
481
482
483
{
    async fn generate(
        &self,
484
        request: SingleIn<NvCreateCompletionRequest>,
485
    ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);
        let response_generator = request.response_generator();

        let messages = RequestMessage::Completion {
            text: prompt_to_string(&request.inner.prompt),
            echo_prompt: false,
            best_of: Some(1),
        };
        let det = SamplingParams::deterministic();
        // allow deprecated because max_tokens
        #[allow(deprecated)]
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request
                .inner
                .stop
                .clone()
                .map(to_stop_tokens)
                .or(det.stop_toks),
519
520
521
522
523
524
525
526
527
528
529
            max_len: {
                let requested_max_tokens = request.inner.max_tokens.map(|m| m as usize);

                // Ensure max_len doesn't exceed context length
                match requested_max_tokens {
                    Some(max_tokens) => Some(std::cmp::min(max_tokens, self.context_length)),
                    None => det
                        .max_len
                        .map(|len| std::cmp::min(len, self.context_length)),
                }
            },
530
531
532
533
534
535
536
537
538
539
540
541
542
543
            logits_bias: request
                .inner
                .logit_bias
                .clone()
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };

        let request_id = self.mistralrs.next_request_id();
544
        let mistralrs_request = Request::Normal(Box::new(NormalRequest {
545
            id: request_id,
546
            model_id: Some(self.display_name.clone()),
547
548
549
550
551
552
553
554
555
556
557
            messages,
            sampling_params,
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
558
            web_search_options: None,
559
        }));
560

561
562
563
564
        self.mistralrs
            .get_sender(None)?
            .send(mistralrs_request)
            .await?;
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
                        break;
                    }
                };
                match response {
                    ResponseOk::CompletionChunk(c) => {
                        let from_assistant = c.choices[0].text.clone();

                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
                                Some(FinishReason::Length)
                            }
                            Some(s) => {
                                tracing::warn!(request_id, stop_reason = s, "Unknow stop reason");
                                Some(FinishReason::Stop)
                            }
                            None => None,
                        };
                        #[allow(deprecated)]
Greg Clark's avatar
Greg Clark committed
593
                        let inner = response_generator.create_choice(0, Some(from_assistant), None, None);
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
                        let ann = Annotated{
                            id: None,
                            data: Some(inner),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            break;
                        }
                    },
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
613
614
615
616
617
618
619
620
621
622
623
624

/// True when the model must be loaded with the vision loader (Gemma 3 or Llama 4).
fn is_vision_model(s: &str) -> bool {
    let vision_families: [fn(&str) -> bool; 2] = [is_gemma3, is_llama4];
    vision_families.iter().any(|matches_family| matches_family(s))
}

/// Case-insensitive check for a Gemma 3 model name (contains `gemma-3`).
fn is_gemma3(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("gemma-3")
}

/// Case-insensitive check for a Llama 4 model name (contains `llama-4`).
fn is_llama4(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("llama-4")
}
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640

#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
        _request: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        unimplemented!()
    }
}