// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::{num::NonZero, path::Path, sync::Arc};

use async_openai::types::FinishReason;
use async_stream::stream;
use async_trait::async_trait;
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
    AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
    GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
    ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
    Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens, TokenSource,
};
use tokio::sync::mpsc::channel;

use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::error as pipeline_error;
use dynamo_runtime::pipeline::{Error, ManyOut, SingleIn};
use dynamo_runtime::protocols::annotated::Annotated;

use dynamo_llm::protocols::openai::{
    chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
    completions::{prompt_to_string, CompletionRequest, CompletionResponse},
};

use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};

/// How many requests mistral.rs will run at once in the paged attention scheduler.
/// It actually runs one fewer than this.
/// Despite appearances, this is apparently not the same thing as the batch size.
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;

/// Experimental: Switch this to true to enable paged attention on CUDA devices.
/// Under load (dynamo-run batch mode) paged attention sometimes returns an immediate
/// finish_reason=stop and no tokens for one of the requests.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;

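/// Build a [`StreamingEngine`] backed by mistral.rs from a local model path
/// (a GGUF file or an HF-style model directory).
///
/// A minimal usage sketch (illustrative, not compiled as a doctest; assumes a
/// tokio runtime and a model at the given path):
///
/// ```ignore
/// let engine = make_engine(Path::new("model.gguf")).await?;
/// ```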
pub async fn make_engine(gguf_path: &Path) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
    let engine = MistralRsEngine::new(gguf_path).await?;
    let engine: Arc<dyn StreamingEngine> = Arc::new(EngineDispatcher::new(engine));
    Ok(engine)
}

/// Gets the best available device: CUDA if available (when compiled with CUDA
/// support), otherwise CPU; Metal when compiled with the `metal` feature.
fn best_device() -> pipeline_error::Result<Device> {
    #[cfg(not(feature = "metal"))]
    {
        Ok(Device::cuda_if_available(0)?)
    }
    #[cfg(feature = "metal")]
    {
        Ok(Device::new_metal(0)?)
    }
}

struct MistralRsEngine {
    mistralrs: Arc<MistralRs>,
}

impl MistralRsEngine {
    async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
        let loader = if model_path.is_file() {
            // Load from a GGUF
            let Some(model_filename) = model_path.file_name() else {
                pipeline_error::bail!("Missing filename in model path");
            };
            let Some(model_dir) = model_path.parent() else {
                pipeline_error::bail!("Invalid model path");
            };

            GGUFLoaderBuilder::new(
                None,
                None,
                model_dir.display().to_string(),
                vec![model_filename.to_string_lossy().into_owned()],
                GGUFSpecificConfig {
                    prompt_chunksize: None,
                    topology: None,
                },
            )
            .build()
        } else {
            // Load from a HF repo dir
            NormalLoaderBuilder::new(
                NormalSpecificConfig {
                    use_flash_attn: false,
                    prompt_chunksize: None,
                    topology: None,
                    organization: Default::default(),
                    write_uqff: None,
                    from_uqff: None,
                    imatrix: None,
                    calibration_file: None,
                },
                None,
                None,
                Some(model_path.display().to_string()),
            )
            .build(None)?
        };

        let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;

        // Paged attention requires CUDA
        let paged_attention_config = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
            Some(PagedAttentionConfig::new(
                None, // Block size, default 32
                4096, // CPU memory in MiB
                MemoryGpuConfig::ContextSize(max_seq_len),
            )?)
        } else {
            None
        };
        // Load the model into a Pipeline
        let pipeline = loader.load_model_from_hf(
            None,
            TokenSource::None, // The model was already downloaded
            &ModelDType::Auto,
            &best_device()?,
            false,
            DeviceMapSetting::Auto(AutoDeviceMapParams::Text {
                max_seq_len,
                max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
            }),
            None,
            paged_attention_config,
        )?;
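        // Choose the scheduler: paged attention metadata when enabled on CUDA,
        // otherwise the default fixed-size scheduler.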
        let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
            tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
            let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
                Some(conf) => conf.clone(),
                None => {
                    anyhow::bail!("Failed loading model config");
                }
            };
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: PAGED_ATTENTION_MAX_NUM_SEQS,
                config,
            }
        } else {
            tracing::debug!("Using mistralrs DefaultScheduler");
            SchedulerConfig::DefaultScheduler {
                // Safety: unwrap trivially safe here
                method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
            }
        };
        // Create the MistralRs, which is a runner
        let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
        let engine = MistralRsEngine {
            mistralrs: builder.build(),
        };

        // Skip the id used for the dummy run; see https://github.com/EricLBuehler/mistral.rs/issues/1218
        let _ = engine.mistralrs.next_request_id();

        // Perform a warmup request so one-time initialization happens before real traffic
        let (tx, mut rx) = channel(1);
        let request_id = engine.mistralrs.next_request_id();
        let warmup_request = Request::Normal(NormalRequest {
            id: request_id,
            messages: RequestMessage::Chat(vec![IndexMap::from([
                ("role".to_string(), Either::Left("user".to_string())),
                ("content".to_string(), Either::Left("test".to_string())),
            ])]),
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            adapters: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
        });

        // Send warmup request and consume response
        if let Ok(sender) = engine.mistralrs.get_sender() {
            if let Ok(()) = sender.send(warmup_request).await {
                if let Some(response) = rx.recv().await {
                    match response.as_result() {
                        Ok(r) => {
                            tracing::debug!(request_id, "Warmup response: {r:?}");
                        }
                        Err(err) => {
                            tracing::error!(request_id, %err, "Failed converting response to result.");
                        }
                    }
                }
            }
        }

        Ok(engine)
    }
}

#[async_trait]
impl
    AsyncEngine<
        SingleIn<NvCreateChatCompletionRequest>,
        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
        Error,
    > for MistralRsEngine
{
    async fn generate(
        &self,
        request: SingleIn<NvCreateChatCompletionRequest>,
    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);

        let mut messages = vec![];
        for m in request.inner.messages {
            let async_openai::types::ChatCompletionRequestMessage::User(inner_m) = m else {
                continue;
            };
            let async_openai::types::ChatCompletionRequestUserMessageContent::Text(content) =
                inner_m.content
            else {
                anyhow::bail!("Only Text type chat completion supported");
            };
            let r = IndexMap::from([
                ("role".to_string(), Either::Left("user".to_string())),
                ("content".to_string(), Either::Left(content)),
            ]);
            messages.push(r);
        }
        if messages.is_empty() {
            anyhow::bail!("Empty request");
        }

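        // Map the OpenAI sampling fields onto mistralrs SamplingParams; anything
        // the request leaves unset falls back to the deterministic defaults.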
        let det = SamplingParams::deterministic();
        // allow deprecated because max_tokens
        #[allow(deprecated)]
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .top_logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
            max_len: request
                .inner
                .max_completion_tokens
                .or(request.inner.max_tokens)
                .map(|m| m as usize)
                .or(det.max_len),
            logits_bias: request
                .inner
                .logit_bias
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };
        let request_id = self.mistralrs.next_request_id();
        let mistralrs_request = Request::Normal(NormalRequest {
            id: request_id,
            messages: RequestMessage::Chat(messages),
            sampling_params,
            response: tx,
            return_logprobs: request.inner.logprobs.unwrap_or_default(),
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            adapters: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
        });

        self.mistralrs.get_sender()?.send(mistralrs_request).await?;

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
                        break;
                    }
                };
                match response {
                    ResponseOk::Chunk(c) => {
                        let Some(from_assistant) = c.choices[0].delta.content.clone() else {
                            tracing::warn!(request_id, "No content from mistralrs. Abandoning request.");
                            break;
                        };
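                        // Map mistralrs string finish reasons onto async-openai's
                        // FinishReason; "canceled" is treated as a normal stop.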
                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
                                Some(FinishReason::Length)
                            }
                            Some(s) => {
                                tracing::warn!(request_id, stop_reason = s, "Unknown stop reason");
                                Some(FinishReason::Stop)
                            }
                            None => None,
                        };
                        //tracing::trace!("from_assistant: {from_assistant}");

                        #[allow(deprecated)]
                        let inner = async_openai::types::CreateChatCompletionStreamResponse{
                            id: c.id,
                            choices: vec![async_openai::types::ChatChoiceStream{
                                index: 0,
                                delta: async_openai::types::ChatCompletionStreamResponseDelta{
                                    //role: c.choices[0].delta.role,
                                    role: Some(async_openai::types::Role::Assistant),
                                    content: Some(from_assistant),
                                    tool_calls: None,
                                    refusal: None,
                                    function_call: None,
                                },
                                logprobs: None,
                                finish_reason,
                            }],
                            model: c.model,
                            created: c.created as u32,
                            object: c.object.clone(),
                            usage: None,
                            system_fingerprint: Some(c.system_fingerprint),
                            service_tier: None,
                        };
                        let delta = NvCreateChatCompletionStreamResponse{inner};
                        let ann = Annotated{
                            id: None,
                            data: Some(delta),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            //tracing::trace!(request_id, "Finish reason: {finish_reason:?}");
                            break;
                        }
                    },
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}

/// Convert OpenAI stop tokens to mistralrs stop tokens
fn to_stop_tokens(t: async_openai::types::Stop) -> StopTokens {
    match t {
        async_openai::types::Stop::String(s) => StopTokens::Seqs(vec![s]),
        async_openai::types::Stop::StringArray(v) => StopTokens::Seqs(v),
    }
}

/// Convert an OpenAI logit bias map (string keys, JSON number values) to the
/// mistralrs form (u32 token ids to f32 biases).
/// The input looks like: {"3721": -100, "17765": 100}
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let mut out = HashMap::new();
    for (key, value) in &lb {
        let token_id: u32 = match key.parse() {
            Ok(t) => t,
            Err(err) => {
                tracing::warn!(
                    "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                );
                return HashMap::new();
            }
        };
        let Some(bias) = value.as_f64() else {
            tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
            return HashMap::new();
        };
        out.insert(token_id, bias as f32);
    }
    out
}

#[async_trait]
impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
    for MistralRsEngine
{
    async fn generate(
        &self,
        request: SingleIn<CompletionRequest>,
    ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
        let (request, context) = request.transfer(());
        let ctx = context.context();
        let (tx, mut rx) = channel(10_000);
        let response_generator = request.response_generator();

        let messages = RequestMessage::Completion {
            text: prompt_to_string(&request.inner.prompt),
            echo_prompt: false,
            best_of: Some(1),
        };
        let det = SamplingParams::deterministic();
        // allow deprecated because max_tokens
        #[allow(deprecated)]
        let sampling_params = SamplingParams {
            temperature: request
                .inner
                .temperature
                .map(|t| t as f64)
                .or(det.temperature),
            top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
            top_n_logprobs: request
                .inner
                .logprobs
                .map(|t| t as usize)
                .unwrap_or(det.top_n_logprobs),
            frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
            presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
            stop_toks: request
                .inner
                .stop
                .clone()
                .map(to_stop_tokens)
                .or(det.stop_toks),
            max_len: request
                .inner
                .max_tokens
                .map(|m| m as usize)
                .or(det.max_len),
            logits_bias: request
                .inner
                .logit_bias
                .clone()
                .map(to_logit_bias)
                .or(det.logits_bias),
            // These are not in async-openai yet
            top_k: det.top_k,
            min_p: det.min_p,
            n_choices: 1,
            dry_params: det.dry_params,
        };

        let request_id = self.mistralrs.next_request_id();
        let mistralrs_request = Request::Normal(NormalRequest {
            id: request_id,
            messages,
            sampling_params,
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            adapters: None,
            tools: None,
            tool_choice: None,
            logits_processors: None,
            return_raw_logits: false,
        });

        self.mistralrs.get_sender()?.send(mistralrs_request).await?;

        let output = stream! {
            while let Some(response) = rx.recv().await {
                let response = match response.as_result() {
                    Ok(r) => r,
                    Err(err) => {
                        tracing::error!(request_id, %err, "Failed converting mistralrs channel response to result.");
                        break;
                    }
                };
                match response {
                    ResponseOk::CompletionChunk(c) => {
                        let from_assistant = c.choices[0].text.clone();

                        let finish_reason = match &c.choices[0].finish_reason.as_deref() {
                            Some("stop") | Some("canceled") => {
                                Some(FinishReason::Stop)
                            }
                            Some("length") => {
                                Some(FinishReason::Length)
                            }
                            Some(s) => {
                                tracing::warn!(request_id, stop_reason = s, "Unknown stop reason");
                                Some(FinishReason::Stop)
                            }
                            None => None,
                        };
                        #[allow(deprecated)]
                        let inner = response_generator.create_choice(0, Some(from_assistant), None);
                        let ann = Annotated{
                            id: None,
                            data: Some(inner),
                            event: None,
                            comment: None,
                        };
                        yield ann;

                        if finish_reason.is_some() {
                            break;
                        }
                    },
                    x => tracing::error!(request_id, "Unhandled. {x:?}"),
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
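
// A minimal sketch of unit tests for the conversion helpers above. Illustrative
// only: it assumes serde_json's `json!` macro is available (serde_json is already
// used by this module) and covers the happy path plus one malformed key.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stop_string_becomes_single_seq() {
        // A single OpenAI stop string should map to a one-element Seqs list.
        match to_stop_tokens(async_openai::types::Stop::String("END".to_string())) {
            StopTokens::Seqs(v) => assert_eq!(v, vec!["END".to_string()]),
            _ => panic!("expected StopTokens::Seqs"),
        }
    }

    #[test]
    fn logit_bias_parses_numeric_keys() {
        // Keys are token ids as strings; values are JSON numbers.
        let lb = HashMap::from([
            ("3721".to_string(), serde_json::json!(-100)),
            ("17765".to_string(), serde_json::json!(100)),
        ]);
        let out = to_logit_bias(lb);
        assert_eq!(out.get(&3721), Some(&-100.0));
        assert_eq!(out.get(&17765), Some(&100.0));
    }

    #[test]
    fn logit_bias_rejects_non_numeric_keys() {
        // A non-integer key makes the helper warn and return an empty map.
        let lb = HashMap::from([("not-a-token".to_string(), serde_json::json!(1))]);
        assert!(to_logit_bias(lb).is_empty());
    }
}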