"tests/kernels/quantization/untest_machete_mm.py" did not exist on "ebe56a0064f7a72a5c51d4cd6bcca165590c5bed"
lib.rs 22.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
use std::collections::HashMap;
17
use std::{num::NonZero, sync::Arc};
18

Paul Hendricks's avatar
Paul Hendricks committed
19
use async_openai::types::FinishReason;
20
21
22
23
24
use async_stream::stream;
use async_trait::async_trait;
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
25
    AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
26
    GGUFLoaderBuilder, GGUFSpecificConfig, IsqType, MemoryGpuConfig, MistralRs, MistralRsBuilder,
27
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
28
    Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens, TokenSource,
29
    VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
30
31
32
};
use tokio::sync::mpsc::channel;

Neelay Shah's avatar
Neelay Shah committed
33
34
35
36
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;
37

38
39
40
use dynamo_llm::protocols::openai::{
    chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
    completions::{prompt_to_string, CompletionRequest, CompletionResponse},
41
    embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
42
};
43
44

use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
45
use dynamo_llm::local_model::LocalModel;
46

47
48
49
50
51
52
53
54
55
/// Maximum number of sequences the mistral.rs paged attention scheduler runs at once.
/// In practice it runs one fewer than this value.
/// This is not what mistral.rs calls the batch size.
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;

/// Experimental toggle: set to `true` to enable paged attention on CUDA devices.
/// Known issue: under load (dynamo-run batch mode) paged attention sometimes returns an
/// immediate finish_reason=stop with no tokens for one of the requests.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;

/// The first message sent to mistral.rs, purely to warm it up (possibly unnecessary).
const WARMUP_MESSAGE: &str = "This is a test message. Respond only with 'OK'.";

/// Builds a mistral.rs-backed streaming engine for `model`.
///
/// The model is loaded and warmed up by `MistralRsEngine::new`. The result is wrapped in an
/// [`EngineDispatcher`] and returned as a type-erased [`StreamingEngine`].
///
/// # Errors
/// Propagates any failure from loading the model.
pub async fn make_engine(model: &LocalModel) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
    let mistral = MistralRsEngine::new(model).await?;
    let dispatcher: Arc<dyn StreamingEngine> = Arc::new(EngineDispatcher::new(mistral));
    Ok(dispatcher)
}

/// Picks the device mistral.rs should run on.
///
/// With the `metal` feature this is Metal device 0. Otherwise it is whatever
/// `Device::cuda_if_available(0)` returns (CUDA device 0, or CPU when CUDA is unavailable).
fn best_device() -> pipeline_error::Result<Device> {
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = Device::cuda_if_available(0)?;
    Ok(device)
}

/// Engine adapter around a mistral.rs runner; every request is sent to `mistralrs`.
struct MistralRsEngine {
    /// Shared mistral.rs runner built in `MistralRsEngine::new`.
    mistralrs: Arc<MistralRs>,
}

impl MistralRsEngine {
83
84
85
86
87
88
89
90
    async fn new(model: &LocalModel) -> pipeline_error::Result<Self> {
        let model_path = model.path();
        // Name some None's for clarity
        let chat_template = None;
        let tokenizer_json = None;
        let no_kv_cache = false;
        let jinja_explicit = None;
        let display_name = model.display_name();
91
92
93
94
95
96
97
98
        let loader = if model_path.is_file() {
            // Load from a GGUF
            let Some(model_filename) = model_path.file_name() else {
                pipeline_error::bail!("Missing filename in model path");
            };
            let Some(model_dir) = model_path.parent() else {
                pipeline_error::bail!("Invalid model path");
            };
99

100
            GGUFLoaderBuilder::new(
101
                chat_template,
102
103
104
105
106
107
108
                None,
                model_dir.display().to_string(),
                vec![model_filename.to_string_lossy().into_owned()],
                GGUFSpecificConfig {
                    prompt_chunksize: None,
                    topology: None,
                },
109
110
                no_kv_cache,
                jinja_explicit,
111
112
            )
            .build()
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
        } else if is_vision_model(display_name) {
            let vlt = if is_gemma3(display_name) {
                VisionLoaderType::Gemma3
            } else if is_llama4(display_name) {
                VisionLoaderType::Llama4
            } else {
                panic!("Unsupported vision model {display_name}");
            };
            VisionLoaderBuilder::new(
                VisionSpecificConfig::default(),
                chat_template,
                tokenizer_json,
                Some(model_path.display().to_string()),
                jinja_explicit,
            )
            .build(vlt)
129
130
131
        } else {
            // Load from a HF repo dir
            NormalLoaderBuilder::new(
132
133
134
                NormalSpecificConfig::default(),
                chat_template,
                tokenizer_json,
135
                Some(model_path.display().to_string()),
136
137
                no_kv_cache,
                jinja_explicit,
138
139
140
            )
            .build(None)?
        };
141

142
143
        let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;

144
        // Paged attention requires cuda
145
        let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
146
            Some(PagedAttentionConfig::new(
147
                None, // Block size, default 32
148
                4096, // CPU memory in MiB
149
                MemoryGpuConfig::ContextSize(max_seq_len),
150
151
152
153
            )?)
        } else {
            None
        };
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168

        let device_map_params = if is_vision_model(model.display_name()) {
            AutoDeviceMapParams::Vision {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
                max_image_shape: (0, 0),
                max_num_images: 0,
            }
        } else {
            AutoDeviceMapParams::Text {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
            }
        };

169
170
171
        // Load, into a Pipeline
        let pipeline = loader.load_model_from_hf(
            None,
172
            TokenSource::None, // The model was already downloaded
173
174
175
            &ModelDType::Auto,
            &best_device()?,
            false,
176
177
178
179
180
181
            DeviceMapSetting::Auto(device_map_params),
            if is_llama4(display_name) {
                Some(IsqType::Q4K)
            } else {
                None
            },
182
183
            paged_attention_config,
        )?;
184
        let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
185
186
187
188
189
190
191
192
            tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
            let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
                Some(conf) => conf.clone(),
                None => {
                    anyhow::bail!("Failed loading model config");
                }
            };
            SchedulerConfig::PagedAttentionMeta {
193
                max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
194
195
196
197
198
                config,
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                // Safety: unwrap trivially safe here
199
                method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
200
201
202
            }
        };
        // Create the MistralRs, which is a runner
203
204
205
206
207
208
209
210
211
        let throughput_logging = false;
        let search_embedding_model = None;
        let builder = MistralRsBuilder::new(
            pipeline.clone(),
            scheduler,
            throughput_logging,
            search_embedding_model,
        )
        .with_prefix_cache_n(16);
212
        let engine = MistralRsEngine {
213
            mistralrs: builder.build(),
214
        };
215

216
217
        // skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
        let _ = engine.mistralrs.next_request_id();
218
219
220
221
222
223

        // Perform warmup request
        let (tx, mut rx) = channel(1);
        let request_id = engine.mistralrs.next_request_id();
        let warmup_request = Request::Normal(NormalRequest {
            id: request_id,
224
225
226
227
228
229
230
231
232
233
            messages: RequestMessage::Chat {
                messages: vec![IndexMap::from([
                    ("role".to_string(), Either::Left("user".to_string())),
                    (
                        "content".to_string(),
                        Either::Left(WARMUP_MESSAGE.to_string()),
                    ),
                ])],
                enable_thinking: Some(false),
            },
234
235
236
237
238
239
240
241
242
243
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
244
            web_search_options: None,
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
        });

        // Send warmup request and consume response
        if let Ok(sender) = engine.mistralrs.get_sender() {
            if let Ok(()) = sender.send(warmup_request).await {
                if let Some(response) = rx.recv().await {
                    match response.as_result() {
                        Ok(r) => {
                            tracing::debug!(request_id, "Warmup response: {r:?}");
                        }
                        Err(err) => {
                            tracing::error!(request_id, %err, "Failed converting response to result.");
                        }
                    }
                }
            }
        }

263
        Ok(engine)
264
265
266
267
268
269
    }
}

#[async_trait]
impl
    AsyncEngine<
270
        SingleIn<NvCreateChatCompletionRequest>,
271
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
272
273
274
275
276
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
277
        request: SingleIn<NvCreateChatCompletionRequest>,
278
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
279
280
281
282
283
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);

        let mut messages = vec![];
Paul Hendricks's avatar
Paul Hendricks committed
284
285
286
287
        for m in request.inner.messages {
            let async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
                continue;
            };
288
289
290
291
            let async_openai::types::ChatCompletionRequestUserMessageContent::Text(content) =
                inner_m.content
            else {
                anyhow::bail!("Only Text type chat completion supported");
292
293
            };
            let r = IndexMap::from([
Paul Hendricks's avatar
Paul Hendricks committed
294
                ("role".to_string(), Either::Left("user".to_string())),
295
296
297
298
299
300
301
302
                ("content".to_string(), Either::Left(content)),
            ]);
            messages.push(r);
        }
        if messages.is_empty() {
            anyhow::bail!("Empty request");
        }

303
        let det = SamplingParams::deterministic();
304
305
        // allow deprecated because max_tokens
        #[allow(deprecated)]
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .top_logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
            max_len: request
                .inner
                .max_completion_tokens
324
                .or(request.inner.max_tokens)
325
326
327
328
329
330
331
332
333
334
335
336
337
                .map(|m| m as usize)
                .or(det.max_len),
            logits_bias: request
                .inner
                .logit_bias
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };
338
        let request_id = self.mistralrs.next_request_id();
339
        let mistralrs_request = Request::Normal(NormalRequest {
340
            id: request_id,
341
342
343
344
            messages: RequestMessage::Chat {
                messages,
                enable_thinking: None,
            },
345
            sampling_params,
346
            response: tx,
347
            return_logprobs: request.inner.logprobs.unwrap_or_default(),
348
349
350
351
352
353
354
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
355
            web_search_options: None,
356
357
358
359
360
361
362
363
364
        });

        self.mistralrs.get_sender()?.send(mistralrs_request).await?;

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
365
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
366
367
368
369
370
                        break;
                    }
                };
                match response {
                    ResponseOk::Chunk(c) => {
371
                        let Some(from_assistant) = c.choices[0].delta.content.clone() else {
372
                            tracing::warn!(request_id, "No content from mistralrs. Abandoning request.");
373
374
                            break;
                        };
375
376
377
378
379
                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
Paul Hendricks's avatar
Paul Hendricks committed
380
                                Some(FinishReason::Length)
381
                            }
382
                            Some(s) => {
383
                                tracing::warn!(request_id, stop_reason = s, "Unknow stop reason");
384
385
                                Some(FinishReason::Stop)
                            }
386
387
388
389
                            None => None,
                        };
                        //tracing::trace!("from_assistant: {from_assistant}");

Paul Hendricks's avatar
Paul Hendricks committed
390
391
                        #[allow(deprecated)]
                        let inner = async_openai::types::CreateChatCompletionStreamResponse{
392
                            id: c.id,
Paul Hendricks's avatar
Paul Hendricks committed
393
                            choices: vec![async_openai::types::ChatChoiceStream{
394
                                index: 0,
Paul Hendricks's avatar
Paul Hendricks committed
395
                                delta: async_openai::types::ChatCompletionStreamResponseDelta{
396
                                    //role: c.choices[0].delta.role,
Paul Hendricks's avatar
Paul Hendricks committed
397
                                    role: Some(async_openai::types::Role::Assistant),
398
399
                                    content: Some(from_assistant),
                                    tool_calls: None,
Paul Hendricks's avatar
Paul Hendricks committed
400
401
                                    refusal: None,
                                    function_call: None,
402
403
404
405
406
                                },
                                logprobs: None,
                                finish_reason,
                            }],
                            model: c.model,
Paul Hendricks's avatar
Paul Hendricks committed
407
                            created: c.created as u32,
408
409
410
411
412
                            object: c.object.clone(),
                            usage: None,
                            system_fingerprint: Some(c.system_fingerprint),
                            service_tier: None,
                        };
413
                        let delta = NvCreateChatCompletionStreamResponse{inner};
414
415
416
417
418
419
420
421
422
                        let ann = Annotated{
                            id: None,
                            data: Some(delta),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
423
                            //tracing::trace!(request_id, "Finish reason: {finish_reason:?}");
424
425
426
                            break;
                        }
                    },
427
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
428
429
430
431
432
433
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464

/// Converts an OpenAI `stop` field into mistral.rs stop sequences.
///
/// A single string becomes a one-element sequence list; an array is passed through unchanged.
fn to_stop_tokens(t: async_openai::types::Stop) -> StopTokens {
    use async_openai::types::Stop;

    let seqs = match t {
        Stop::String(single) => vec![single],
        Stop::StringArray(many) => many,
    };
    StopTokens::Seqs(seqs)
}

/// Converts an OpenAI `logit_bias` map (string keys, JSON values) into mistral.rs form
/// (`u32` token id -> `f32` bias).
///
/// The input is expected to look like `{"3721": -100, "17765": 100}`. If any key does not
/// parse as a `u32`, or any value is not a JSON number, a warning is logged and an empty map
/// is returned. In that case the whole bias is ignored, not just the bad entry.
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let converted: Option<HashMap<u32, f32>> = lb
        .iter()
        .map(|(key, value)| {
            let token_id: u32 = match key.parse() {
                Ok(t) => t,
                Err(err) => {
                    tracing::warn!(
                        "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                    );
                    return None;
                }
            };
            let Some(bias) = value.as_f64() else {
                tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
                return None;
            };
            Some((token_id, bias as f32))
        })
        // Collecting into Option stops at the first invalid entry.
        .collect();
    converted.unwrap_or_default()
}
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539

#[async_trait]
impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
    for MistralRsEngine
{
    async fn generate(
        &self,
        request: SingleIn<CompletionRequest>,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);
        let response_generator = request.response_generator();

        let messages = RequestMessage::Completion {
            text: prompt_to_string(&request.inner.prompt),
            echo_prompt: false,
            best_of: Some(1),
        };
        let det = SamplingParams::deterministic();
        // allow deprecated because max_tokens
        #[allow(deprecated)]
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request
                .inner
                .stop
                .clone()
                .map(to_stop_tokens)
                .or(det.stop_toks),
            max_len: request
                .inner
                .max_tokens
                .or(request.inner.max_tokens)
                .map(|m| m as usize)
                .or(det.max_len),
            logits_bias: request
                .inner
                .logit_bias
                .clone()
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };

        let request_id = self.mistralrs.next_request_id();
        let mistralrs_request = Request::Normal(NormalRequest {
            id: request_id,
            messages,
            sampling_params,
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
540
            web_search_options: None,
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
        });

        self.mistralrs.get_sender()?.send(mistralrs_request).await?;

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
                        break;
                    }
                };
                match response {
                    ResponseOk::CompletionChunk(c) => {
                        let from_assistant = c.choices[0].text.clone();

                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
                                Some(FinishReason::Length)
                            }
                            Some(s) => {
                                tracing::warn!(request_id, stop_reason = s, "Unknow stop reason");
                                Some(FinishReason::Stop)
                            }
                            None => None,
                        };
                        #[allow(deprecated)]
                        let inner = response_generator.create_choice(0, Some(from_assistant), None);
                        let ann = Annotated{
                            id: None,
                            data: Some(inner),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            break;
                        }
                    },
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
592
593
594
595
596
597
598
599
600
601
602
603

/// Returns `true` when `s` names a model that must go through the mistral.rs vision loader
/// (currently Gemma 3 or Llama 4).
fn is_vision_model(s: &str) -> bool {
    if is_gemma3(s) {
        return true;
    }
    is_llama4(s)
}

/// Returns `true` when the model name contains `gemma-3`, ignoring case.
fn is_gemma3(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("gemma-3")
}

/// Returns `true` when the model name contains `llama-4`, ignoring case.
fn is_llama4(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("llama-4")
}
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619

#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateEmbeddingRequest>,
        ManyOut<Annotated<NvCreateEmbeddingResponse>>,
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
        _request: SingleIn<NvCreateEmbeddingRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateEmbeddingResponse>>, Error> {
        unimplemented!()
    }
}