Unverified Commit 4ab47617 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(engines): Upgrade mistralrs to 0.6.0 (#1767)

parent 7a353e61
This diff is collapsed.
......@@ -76,7 +76,7 @@ tokio-util = { version = "0.7", features = ["codec", "net"] }
tracing = { version = "0.1" }
tracing-subscriber = { version = "0.3", features = ["env-filter", "local-time", "json"] }
validator = { version = "0.20.0", features = ["derive"] }
uuid = { version = "1", features = ["v4", "serde"] }
uuid = { version = "1.17", features = ["v4", "serde"] }
url = {version = "2.5", features = ["serde"]}
xxhash-rust = { version = "0.8", features = ["xxh3", "const_xxh3"] }
......
......@@ -30,7 +30,9 @@ allow = [
"OpenSSL",
"Unicode-3.0",
"BSL-1.0",
"MPL-2.0"
"MPL-2.0",
"CDLA-Permissive-2.0",
"Zlib"
]
# TODO exceptions
......
......@@ -5133,12 +5133,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.16.0"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d"
dependencies = [
"getrandom 0.3.2",
"js-sys",
"serde",
"wasm-bindgen",
]
[[package]]
......
......@@ -26,7 +26,7 @@ keywords.workspace = true
[features]
default = []
cuda = ["mistralrs/cuda", "candle-core/cuda"]
cuda = ["mistralrs/cuda"]
metal = ["mistralrs/metal"]
[dependencies]
......@@ -37,10 +37,9 @@ anyhow = { workspace = true }
async-openai = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
candle-core = { version = "0.8.0" }
either = { workspace = true }
indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "ebd50e35e" }
indexmap = { version = "2.9.0", features = ["serde"] }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", version = "0.6.0" }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
......@@ -13,8 +13,8 @@ use mistralrs::{
AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, IsqType, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens, TokenSource,
VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
PagedCacheType, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig,
StopTokens, TokenSource, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
};
use tokio::sync::mpsc::channel;
......@@ -66,6 +66,7 @@ fn best_device() -> pipeline_error::Result<Device> {
struct MistralRsEngine {
mistralrs: Arc<MistralRs>,
context_length: usize,
display_name: String,
}
impl MistralRsEngine {
......@@ -114,7 +115,7 @@ impl MistralRsEngine {
Some(model_path.display().to_string()),
jinja_explicit,
)
.build(vlt)
.build(Some(vlt))
} else {
// Load from a HF repo dir
NormalLoaderBuilder::new(
......@@ -140,6 +141,7 @@ impl MistralRsEngine {
None, // Block size, default 32
4096, // CPU memory in MiB
MemoryGpuConfig::ContextSize(max_seq_len),
PagedCacheType::Auto,
)?)
} else {
None
......@@ -203,8 +205,9 @@ impl MistralRsEngine {
)
.with_prefix_cache_n(16);
let engine = MistralRsEngine {
mistralrs: builder.build(),
mistralrs: builder.build().await,
context_length: max_seq_len,
display_name: display_name.to_string(),
};
// skip the id used for dummy run https://github.com/EricLBuehler/mistral.rs/issues/1218
......@@ -213,8 +216,9 @@ impl MistralRsEngine {
// Perform warmup request
let (tx, mut rx) = channel(1);
let request_id = engine.mistralrs.next_request_id();
let warmup_request = Request::Normal(NormalRequest {
let warmup_request = Request::Normal(Box::new(NormalRequest {
id: request_id,
model_id: Some(display_name.to_string()),
messages: RequestMessage::Chat {
messages: vec![IndexMap::from([
("role".to_string(), Either::Left("user".to_string())),
......@@ -236,10 +240,10 @@ impl MistralRsEngine {
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
});
}));
// Send warmup request and consume response
if let Ok(sender) = engine.mistralrs.get_sender() {
if let Ok(sender) = engine.mistralrs.get_sender(None) {
if let Ok(()) = sender.send(warmup_request).await {
if let Some(response) = rx.recv().await {
match response.as_result() {
......@@ -339,8 +343,9 @@ impl
dry_params: det.dry_params,
};
let request_id = self.mistralrs.next_request_id();
let mistralrs_request = Request::Normal(NormalRequest {
let mistralrs_request = Request::Normal(Box::new(NormalRequest {
id: request_id,
model_id: Some(self.display_name.clone()),
messages: RequestMessage::Chat {
messages,
enable_thinking: None,
......@@ -356,9 +361,12 @@ impl
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
});
}));
self.mistralrs.get_sender()?.send(mistralrs_request).await?;
self.mistralrs
.get_sender(None)?
.send(mistralrs_request)
.await?;
let output = stream! {
while let Some(response) = rx.recv().await {
......@@ -536,8 +544,9 @@ impl
};
let request_id = self.mistralrs.next_request_id();
let mistralrs_request = Request::Normal(NormalRequest {
let mistralrs_request = Request::Normal(Box::new(NormalRequest {
id: request_id,
model_id: Some(self.display_name.clone()),
messages,
sampling_params,
response: tx,
......@@ -550,9 +559,12 @@ impl
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
});
}));
self.mistralrs.get_sender()?.send(mistralrs_request).await?;
self.mistralrs
.get_sender(None)?
.send(mistralrs_request)
.await?;
let output = stream! {
while let Some(response) = rx.recv().await {
......
......@@ -2951,12 +2951,14 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "uuid"
version = "1.16.0"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9"
checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d"
dependencies = [
"getrandom 0.3.2",
"js-sys",
"serde",
"wasm-bindgen",
]
[[package]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment