Commit 2f700421 authored by Graham King, committed by GitHub

feat: Add a mistralrs engine to tio (#178)

This allows us to run a real model.

Build:
```
cargo build --release --features mistralrs,cuda
```

Run:
```
./target/release/tio in=text out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf
```

Why [mistral.rs](https://github.com/EricLBuehler/mistral.rs)?

- It has no dependencies. You don't need a container or a virtual env to get started.
- It supports CUDA, Metal (macOS), and CPU-only. Everyone can join the AI revolution.
- It starts fast and serves fast (with CUDA). That makes it fun to experiment with.
- It runs many models, not just Mistral; that's just its name.
parent 6d2abdba
@@ -124,7 +124,7 @@
write-debug "<copyright-check> ignored_files = ['$($ignored_files -join "','")']."
$ignored_paths = @('.github', '.mypy_cache', '.pytest_cache')
write-debug "<copyright-check> ignored_paths = ['$($ignored_paths -join "','")']."
$ignored_types = @('.bat', '.gif', '.ico', '.ipynb', '.jpg', '.jpeg', '.patch', '.png', '.pyc', '.pyi', '.rst', '.zip', '.md')
write-debug "<copyright-check> ignored_types = ['$($ignored_types -join "', '")']."
$ignored_folders = @('.git', '__pycache__')
...
This diff is collapsed.
@@ -20,6 +20,11 @@ edition = "2021"
authors = ["NVIDIA"]
homepage = "https://github.com/triton-inference-server/triton_distributed"
[features]
mistralrs = ["triton-llm/mistralrs"]
cuda = ["triton-llm/cuda"]
metal = ["triton-llm/metal"]
[dependencies]
anyhow = "1"
async-stream = { version = "0.3" }
...
# triton-llm service runner
`tio` is a tool for exploring the triton-distributed and triton-llm components.
## Install and start prerequisites
Rust:
```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
Get the NATS server from https://nats.io/download/ and run it:
```
nats-server -js --trace --store_dir $(mktemp -d)
```
Get etcd from https://github.com/etcd-io/etcd/releases and run it: `etcd`
These components are required but not yet used by tio. It's a journey, OK.
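To sanity-check that both are running (assuming the default ports, and that `etcdctl` came with your etcd download), something like this should work:
```bash
# etcd health check on its default client port (2379)
etcdctl endpoint health
# or the same check over plain HTTP
curl http://localhost:2379/health
# NATS listens on 4222 by default; confirm the port is open
nc -z localhost 4222 && echo "nats ok"
```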
## Build
- CUDA:
`cargo build --release --features mistralrs,cuda`
- MAC w/ Metal:
`cargo build --release --features mistralrs,metal`
- CPU only:
`cargo build --release --features mistralrs`
## Download a model from Hugging Face
For example, any of the quantized GGUF files here should be fast and give good quality on almost any machine: https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF
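One way to grab a file is plain `wget` (a rough sketch; the exact filename is an assumption, so check the repo's file list and pick whichever quant you want):
```bash
wget https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q4_K_M.gguf
```
Point `--model-path` below at whichever file you downloaded.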
## Run
*Text interface*
`./target/release/tio in=text out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf`
*HTTP interface*
`./target/release/tio in=http out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf`
List the models: `curl localhost:8080/v1/models`
Send a request:
```
curl -d '{"model": "Llama-3.2-1B-Instruct-Q4_K_M", "max_tokens": 2049, "messages":[{"role":"user", "content": "What is the capital of South Africa?" }]}' -H 'Content-Type: application/json' http://localhost:8080/v1/chat/completions
```
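The endpoint follows the OpenAI chat-completions shape and the engine streams internally, so a streaming request is likely just a matter of adding `"stream": true`. An untested sketch:
```bash
curl -N -H 'Content-Type: application/json' \
  -d '{"model": "Llama-3.2-1B-Instruct-Q4_K_M", "stream": true, "max_tokens": 256, "messages":[{"role":"user", "content": "Name three uses for a paperclip."}]}' \
  http://localhost:8080/v1/chat/completions
```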
Run `tio --help` for more options.
@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::PathBuf;
use triton_distributed::runtime::CancellationToken;
use triton_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
@@ -30,9 +32,14 @@ pub struct Flags {
pub http_port: u16,
/// The name of the model we are serving
/// Later that will come from the HF repo name, and still later from etcd during discovery
#[arg(long)]
pub model_name: Option<String>,
/// Full path to the model. This differs by engine:
/// - mistralrs: File. GGUF.
/// - echo_full: Omit the flag.
#[arg(long)]
pub model_path: Option<PathBuf>,
}
pub enum EngineConfig {
@@ -49,12 +56,44 @@ pub async fn run(
flags: Flags,
cancel_token: CancellationToken,
) -> anyhow::Result<()> {
// Turn relative paths into absolute paths
let model_path = flags.model_path.and_then(|p| p.canonicalize().ok());
// Serve the model under the name provided, or the name of the GGUF file.
let model_name = flags.model_name.or_else(||
// "stem" means the filename without the extension.
model_path.as_ref()
.and_then(|p| p.file_stem())
.map(|n| n.to_string_lossy().into_owned()));
// Create the engine matching `out`
let engine_config = match out_opt {
    Output::EchoFull => {
        let Some(model_name) = model_name else {
            anyhow::bail!(
                "Pass --model-name or --model-path so we know which model to imitate"
            );
        };
        EngineConfig::StaticFull {
            service_name: model_name,
            engine: output::echo_full::make_engine_full(),
        }
    }
    #[cfg(feature = "mistralrs")]
    Output::MistralRs => {
        let Some(model_path) = model_path else {
            anyhow::bail!("out=mistralrs requires flag --model-path=<full-path-to-model-gguf>");
        };
        if !model_path.is_file() {
            anyhow::bail!("--model-path should refer to a GGUF file");
        }
        let Some(model_name) = model_name else {
            unreachable!("We checked model_path earlier, and set model_name from model_path");
        };
        EngineConfig::StaticFull {
            service_name: model_name,
            engine: triton_llm::engines::mistralrs::make_engine(&model_path).await?,
        }
    }
};
match in_opt {
...
@@ -20,10 +20,15 @@ use clap::Parser;
use triton_distributed::logging;
const HELP: &str = r#"
triton-llm service runner
Example:
- cargo build --release --features mistralrs,cuda
- ./target/release/tio in=text out=mistralrs --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf --model-name 'Llama-3.2-1B-Instruct'
"#; "#;
const USAGE: &str = "USAGE: tio in=[http|text] out=[echo_full] [--http-port 8080]"; const USAGE: &str = "USAGE: tio in=[http|text] out=[mistralrs|echo_full] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>]";
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
logging::init(); logging::init();
......
@@ -48,6 +48,10 @@ impl fmt::Display for Input {
pub enum Output {
/// Accept un-preprocessed requests, echo the prompt back as the response
EchoFull,
#[cfg(feature = "mistralrs")]
/// Run inference on a model in a GGUF file using mistralrs w/ candle
MistralRs,
}
impl TryFrom<&str> for Output {
@@ -55,6 +59,9 @@ impl TryFrom<&str> for Output {
fn try_from(s: &str) -> anyhow::Result<Self> {
match s {
#[cfg(feature = "mistralrs")]
"mistralrs" => Ok(Output::MistralRs),
"echo_full" => Ok(Output::EchoFull), "echo_full" => Ok(Output::EchoFull),
e => Err(anyhow::anyhow!("Invalid out= option '{e}'")), e => Err(anyhow::anyhow!("Invalid out= option '{e}'")),
} }
...@@ -64,6 +71,9 @@ impl TryFrom<&str> for Output { ...@@ -64,6 +71,9 @@ impl TryFrom<&str> for Output {
impl fmt::Display for Output { impl fmt::Display for Output {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let s = match self { let s = match self {
#[cfg(feature = "mistralrs")]
Output::MistralRs => "mistralrs",
Output::EchoFull => "echo_full",
};
write!(f, "{s}")
...
This diff is collapsed.
@@ -21,6 +21,11 @@ authors.workspace = true
license.workspace = true
homepage.workspace = true
[features]
mistralrs = ["dep:mistralrs"]
metal = ["mistralrs/metal"]
cuda = ["mistralrs/cuda"]
[dependencies]
# repo
@@ -52,6 +57,10 @@ unicode-segmentation = "1.12"
axum = "0.8"
prometheus = { version = "0.13" }
# mistralrs
either = { version = "1.13" }
indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "c26e6c8", optional = true }
[dev-dependencies]
insta = { version = "1.41", features = ["glob", "json", "redactions"]}
...
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(feature = "mistralrs")]
pub mod mistralrs;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{cmp::min, num::NonZero, path::Path, sync::Arc};
use async_stream::stream;
use async_trait::async_trait;
use either::Either;
use indexmap::IndexMap;
use mistralrs::{
Constraint, DefaultSchedulerMethod, Device, DeviceMapMetadata, GGUFLoaderBuilder,
GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelDType, NormalRequest,
PagedAttentionConfig, Pipeline, Request, RequestMessage, ResponseOk, SamplingParams,
SchedulerConfig, TokenSource,
};
use tokio::sync::mpsc::channel;
use triton_distributed::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed::pipeline::error as pipeline_error;
use triton_distributed::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed::protocols::annotated::Annotated;
use crate::protocols::openai::chat_completions::{
ChatCompletionChoiceDelta, ChatCompletionContent, ChatCompletionRequest,
ChatCompletionResponseDelta, Content, FinishReason, MessageRole,
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
/// If the user does not provide max_tokens, limit prompt + output tokens to this many
const DEFAULT_MAX_TOKENS: i32 = 8192;
pub async fn make_engine(
gguf_path: &Path,
) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> {
let engine = MistralRsEngine::new(gguf_path).await?;
let engine: OpenAIChatCompletionsStreamingEngine = Arc::new(engine);
Ok(engine)
}
/// Gets the best device: CUDA if available (requires the `cuda` feature), Metal when built with the `metal` feature, otherwise CPU
fn best_device() -> pipeline_error::Result<Device> {
#[cfg(not(feature = "metal"))]
{
Ok(Device::cuda_if_available(0)?)
}
#[cfg(feature = "metal")]
{
Ok(Device::new_metal(0)?)
}
}
struct MistralRsEngine {
mistralrs: Arc<MistralRs>,
pipeline: Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync + 'static>>,
}
impl MistralRsEngine {
async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
let Some(model_filename) = model_path.file_name() else {
pipeline_error::bail!("Missing filename in model path");
};
let Some(model_dir) = model_path.parent() else {
pipeline_error::bail!("Invalid model path");
};
// Select a Mistral model
// We do not use any files from HF servers here, and instead load the
// chat template from the specified file, and the tokenizer and model from a
// local GGUF file at the path `.`
let loader = GGUFLoaderBuilder::new(
None,
None,
model_dir.display().to_string(),
vec![model_filename.to_string_lossy().into_owned()],
GGUFSpecificConfig {
prompt_batchsize: None,
topology: None,
},
)
.build();
// Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") {
Some(PagedAttentionConfig::new(
Some(32),
1024,
MemoryGpuConfig::Utilization(0.9),
)?)
} else {
None
};
// Load into a Pipeline
let pipeline = loader.load_model_from_hf(
None,
TokenSource::CacheToken,
&ModelDType::Auto,
&best_device()?,
false,
DeviceMapMetadata::dummy(),
None,
paged_attention_config,
)?;
let scheduler = if cfg!(feature = "cuda") {
tracing::debug!("Using mistralrs PagedAttentionMeta scheduler");
let config = match pipeline.lock().await.get_metadata().cache_config.as_ref() {
Some(conf) => conf.clone(),
None => {
anyhow::bail!("Failed loading model config");
}
};
SchedulerConfig::PagedAttentionMeta {
max_num_seqs: 5,
config,
}
} else {
tracing::debug!("Using mistralrs DefaultScheduler");
SchedulerConfig::DefaultScheduler {
// Safety: unwrap trivially safe here
method: DefaultSchedulerMethod::Fixed(NonZero::new(5).unwrap()),
}
};
// Create the MistralRs, which is a runner
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler);
Ok(MistralRsEngine {
mistralrs: builder.build(),
pipeline,
})
}
}
#[async_trait]
impl
AsyncEngine<
SingleIn<ChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
Error,
> for MistralRsEngine
{
async fn generate(
&self,
request: SingleIn<ChatCompletionRequest>,
) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
let (request, context) = request.transfer(());
let ctx = context.context();
let (tx, mut rx) = channel(10_000);
let maybe_tok = self.pipeline.lock().await.tokenizer();
let mut prompt_tokens = 0;
let mut messages = vec![];
for m in request.messages {
let content = match m.content {
Content::Text(prompt) => {
if let Some(tok) = maybe_tok.as_ref() {
prompt_tokens = tok
.encode(prompt.clone(), false)
.map(|e| e.len() as i32)
.unwrap_or(0);
}
prompt
}
Content::ImageUrl(_) => {
anyhow::bail!("Content::ImageUrl type is not supported");
}
};
let r = IndexMap::from([
("role".to_string(), Either::Left(m.role.to_string())),
("content".to_string(), Either::Left(content)),
]);
messages.push(r);
}
if messages.is_empty() {
anyhow::bail!("Empty request");
}
// TODO tracing::trace print the latest prompt, which should be the last message at user
// level.
//tracing::info!(prompt_tokens, "Received prompt");
let limit = DEFAULT_MAX_TOKENS - prompt_tokens;
let max_output_tokens = min(request.max_tokens.unwrap_or(limit), limit);
let mistralrs_request = Request::Normal(NormalRequest {
messages: RequestMessage::Chat(messages),
sampling_params: SamplingParams::deterministic(),
response: tx,
return_logprobs: false,
is_streaming: true,
id: 0,
constraint: Constraint::None,
suffix: None,
adapters: None,
tools: None,
tool_choice: None,
logits_processors: None,
return_raw_logits: false,
});
self.mistralrs.get_sender()?.send(mistralrs_request).await?;
let mut used_output_tokens = 0;
let output = stream! {
while let Some(response) = rx.recv().await {
let response = match response.as_result() {
Ok(r) => r,
Err(err) => {
tracing::error!(%err, "Failed converting mistralrs channel response to result.");
break;
}
};
match response {
ResponseOk::Chunk(c) => {
let from_assistant = c.choices[0].delta.content.clone();
if let Some(tok) = maybe_tok.as_ref() {
used_output_tokens += tok
.encode(from_assistant.clone(), false)
.map(|e| e.len() as i32)
.unwrap_or(0);
}
let finish_reason = match &c.choices[0].finish_reason {
Some(fr) => Some(fr.parse::<FinishReason>().unwrap_or(FinishReason::null)),
None if used_output_tokens >= max_output_tokens => {
tracing::debug!(used_output_tokens, max_output_tokens, "Met or exceeded max_tokens. Stopping.");
Some(FinishReason::length)
}
None => None,
};
//tracing::trace!("from_assistant: {from_assistant}");
let delta = ChatCompletionResponseDelta{
id: c.id,
choices: vec![ChatCompletionChoiceDelta{
index: 0,
delta: ChatCompletionContent{
//role: c.choices[0].delta.role,
role: Some(MessageRole::assistant),
content: Some(from_assistant),
tool_calls: None,
},
logprobs: None,
finish_reason,
}],
model: c.model,
created: c.created as u64,
object: c.object.clone(),
usage: None,
system_fingerprint: Some(c.system_fingerprint),
service_tier: None,
};
let ann = Annotated{
id: None,
data: Some(delta),
event: None,
comment: None,
};
yield ann;
if finish_reason.is_some() {
//tracing::trace!("Finish reason: {finish_reason:?}");
break;
}
},
x => tracing::error!("Unhandled. {x:?}"),
}
}
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
@@ -21,3 +21,4 @@
pub mod http;
pub mod protocols;
pub mod types;
pub mod engines;