"examples/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "6630fa5c4423af6df9c346d4fba8c9f18eeb6f99"
Commit 6e0cfbd9 authored by Graham King, committed by GitHub

feat: vllm engine (#308)

triton-distributed-llm component and support in tio
parent 37a8ebaf
@@ -27,6 +27,7 @@ sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtne
llamacpp = ["triton-distributed-llm/llamacpp"]
cuda = ["triton-distributed-llm/cuda"]
metal = ["triton-distributed-llm/metal"]
vllm = ["triton-distributed-llm/vllm"]

[dependencies]
anyhow = "1"
...
@@ -106,3 +106,36 @@ The extra `--model-config` flag is because:
- We don't yet read it out of the GGUF (TODO), so we need an HF repo with `tokenizer.json` et al

If the build step also builds llama_cpp libraries into `target/release` ("libllama.so", "libggml.so", "libggml-base.so", "libggml-cpu.so", "libggml-cuda.so"), then `tio` will need to find those at runtime. Set `LD_LIBRARY_PATH`, and be sure to deploy them alongside the `tio` binary.
## vllm
Using the [vllm](https://github.com/vllm-project/vllm) Python library. We only use the back half of vllm, talking to it over `zmq`. Slow startup, fast inference. Supports both safetensors from HF and GGUF files.
We use [uv](https://docs.astral.sh/uv/) but any virtualenv manager should work.
Setup:
```
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install vllm setuptools
```
**Note: If you're on Ubuntu 22.04 or earlier, you will need to add `--python=python3.10` to your `uv venv` command**
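Optional sanity check that vllm is importable in this venv (editor's sketch, not part of the commit; the `vllm` package exposes `__version__`):
```
python -c "import vllm; print(vllm.__version__)"
```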
Build:
```
cargo build --release --features vllm
```
Run (still inside that virtualenv) - HF repo:
```
./target/release/tio in=http out=vllm --model-path ~/llm_models/Llama-3.2-3B-Instruct/
```
Run (still inside that virtualenv) - GGUF:
```
./target/release/tio in=http out=vllm --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf --model-config ~/llm_models/Llama-3.2-3B-Instruct/
```
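Smoke test, as a sketch (not part of the commit): this assumes `in=http` serves an OpenAI-style `/v1/chat/completions` route on the default `--http-port 8080`, and that the served model name defaults to the last path component, per the naming logic in this commit; adjust to your setup:
```
curl -s http://localhost:8080/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{
        "model": "Llama-3.2-3B-Instruct",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 32
      }'
```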
@@ -107,6 +107,12 @@ pub struct Flags {
    #[arg(long)]
    pub dist_init_addr: Option<String>,

    /// Internal use only.
    /// Start the Python vllm engine sub-process.
    #[arg(long)]
    #[clap(hide = true, default_value = "false")]
    pub internal_vllm_process: bool,

    /// Internal use only.
    /// Start the sglang Python sub-process.
    /// The params in the tuple are:
@@ -176,25 +182,37 @@ pub async fn run(
        .model_path_pos
        .or(flags.model_path_flag)
        .and_then(|p| p.canonicalize().ok());

    // Serve the model under the name provided, or the name of the GGUF file or HF repo.
    let model_name = flags.model_name.or_else(|| {
        model_path
            .as_ref()
            .and_then(|p| p.iter().last())
            .map(|n| n.to_string_lossy().into_owned())
    });

    // Load the model deployment card, if any.
    // Only used by some engines, so without those feature flags it's unused.
    #[allow(unused_variables)]
    let (maybe_card_path, maybe_card) = match (&model_path, &flags.model_config) {
        // --model-config takes precedence
        (_, Some(model_config)) => {
            let card = ModelDeploymentCard::from_local_path(model_config, model_name.as_deref())
                .await
                .ok();
            (Some(model_config.clone()), card)
        }
        // If --model-path is an HF repo, use that
        (Some(model_path), _) if model_path.is_dir() => {
            let card = ModelDeploymentCard::from_local_path(model_path, model_name.as_deref())
                .await
                .ok();
            (Some(model_path.clone()), card)
        }
        // Otherwise we don't have one, but we only need it if we're tokenizing
        _ => (None, None),
    };
#[cfg(feature = "sglang")] #[cfg(any(feature = "vllm", feature = "sglang"))]
let mut extra = None; // sglang sub-process let mut extra = None; // vllm and sglang sub-process
// Create the engine matching `out` // Create the engine matching `out`
let engine_config = match out_opt { let engine_config = match out_opt {
@@ -304,6 +322,39 @@ pub async fn run(
                card: Box::new(card),
            }
        }
        #[cfg(feature = "vllm")]
        Output::Vllm => {
            use triton_distributed_llm::engines::vllm;
            let Some(model_path) = model_path else {
                anyhow::bail!(
                    "out=vllm requires flag --model-path=<full-path-to-hf-repo-or-model-gguf>"
                );
            };
            let Some(card_path) = maybe_card_path else {
                // If we have a gguf we also need a model card because we don't currently parse
                // tokenizer et al out of gguf.
                anyhow::bail!(
                    "Running GGUF files also requires a `--model-config` for the tokenizer et al."
                );
            };
            let Some(card) = maybe_card.clone() else {
                anyhow::bail!(
                    "out=vllm requires --model-path to be an HF repo, or for GGUF add flag --model-config <hf-repo>"
                );
            };
            let Some(sock_prefix) = zmq_socket_prefix else {
                anyhow::bail!("vllm requires zmq_socket_prefix");
            };
            let (engine, vllm_process) =
                vllm::make_engine(cancel_token.clone(), &card_path, &model_path, &sock_prefix)
                    .await?;
            extra = Some(vllm_process);
            EngineConfig::StaticCore {
                service_name: card.service_name.clone(),
                engine,
                card: Box::new(card),
            }
        }
#[cfg(feature = "llamacpp")] #[cfg(feature = "llamacpp")]
Output::LlamaCpp => { Output::LlamaCpp => {
use anyhow::Context; use anyhow::Context;
@@ -314,25 +365,10 @@ pub async fn run(
            if !model_path.is_file() {
                anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
            }
            let Some(card) = maybe_card else {
                anyhow::bail!(
                    "Pass --model-config so we can find the tokenizer, should be an HF checkout."
                );
            };
            let engine = llamacpp::make_engine(cancel_token.clone(), &model_path).await?;
            EngineConfig::StaticCore {
@@ -361,9 +397,9 @@ pub async fn run(
        }
    }

    #[cfg(any(feature = "vllm", feature = "sglang"))]
    // Allow engines to ask the main thread to wait on an extra future.
    // vllm and sglang use this to shut down their sub-process.
    if let Some(extra) = extra {
        extra.await?;
    }
...
@@ -41,7 +41,7 @@ const DEFAULT_OUT: Output = Output::EchoFull;
const ZMQ_SOCKET_PREFIX: &str = "tio";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";

fn main() -> anyhow::Result<()> {
    logging::init();
@@ -83,6 +83,28 @@ fn main() -> anyhow::Result<()> {
                panic!("Rebuild with --features=sglang");
            }
        }

        #[allow(unused_variables)]
        if flags.internal_vllm_process {
            let Some(model_path) = flags.model_path_flag else {
                anyhow::bail!("vllm subprocess requires --model-path flag");
            };
            let Some(model_config) = flags.model_config else {
                anyhow::bail!("vllm subprocess requires --model-config");
            };
            if !model_config.is_dir() {
                anyhow::bail!("vllm subprocess requires model config path to be a directory containing tokenizer.json, config.json, etc");
            }
            if cfg!(feature = "vllm") {
                #[cfg(feature = "vllm")]
                {
                    use triton_distributed_llm::engines::vllm;
                    return vllm::run_subprocess(ZMQ_SOCKET_PREFIX, &model_config, &model_path);
                }
            } else {
                panic!("Rebuild with --features=vllm");
            }
        }
    }

    // max_worker_threads and max_blocking_threads from env vars or config file.
...
@@ -83,6 +83,10 @@ pub enum Output {
    #[cfg(feature = "llamacpp")]
    /// Run inference using llama.cpp
    LlamaCpp,

    #[cfg(feature = "vllm")]
    /// Run inference using vllm's engine
    Vllm,
}

impl TryFrom<&str> for Output {
@@ -99,6 +103,9 @@ impl TryFrom<&str> for Output {
            #[cfg(feature = "llamacpp")]
            "llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp),

            #[cfg(feature = "vllm")]
            "vllm" => Ok(Output::Vllm),

            "echo_full" => Ok(Output::EchoFull),
            "echo_core" => Ok(Output::EchoCore),
@@ -124,6 +131,9 @@ impl fmt::Display for Output {
            #[cfg(feature = "llamacpp")]
            Output::LlamaCpp => "llamacpp",

            #[cfg(feature = "vllm")]
            Output::Vllm => "vllm",

            Output::EchoFull => "echo_full",
            Output::EchoCore => "echo_core",
...
@@ -34,6 +34,7 @@ mistralrs = ["dep:mistralrs"]
llamacpp = ["dep:llama-cpp-2"]
sglang = ["dep:async_zmq"]
sentencepiece = ["dep:sentencepiece"]
vllm = ["dep:async_zmq"]
cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"]
metal = ["mistralrs/metal", "llama-cpp-2/metal"]
...
@@ -21,3 +21,6 @@ pub mod sglang;
#[cfg(feature = "llamacpp")]
pub mod llamacpp;

#[cfg(feature = "vllm")]
pub mod vllm;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use std::sync::Arc;

use triton_distributed_runtime::pipeline::error as pipeline_error;
use triton_distributed_runtime::CancellationToken;

use crate::backend::ExecutionContext;

mod worker;

mod engine;
use engine::VllmEngine;

mod subprocess;
pub use subprocess::run_subprocess;

pub async fn make_engine(
    cancel_token: CancellationToken,
    // Where to find the tokenizer and config.json
    card_path: &Path,
    // Full path to the model, either a GGUF file or an HF repo dir
    model_path: &Path,
    // Unique string to name zmq sockets
    sock_code: &str,
) -> pipeline_error::Result<(ExecutionContext, tokio::task::JoinHandle<()>)> {
    let mut engine = VllmEngine::new(cancel_token, sock_code, card_path, model_path).await?;
    // Join handle for the vllm sub-process, so the caller can await its shutdown.
    let vllm_process = engine.take_vllm_worker_handle();
    let engine: ExecutionContext = Arc::new(engine);
    Ok((engine, vllm_process))
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use async_stream::stream;
use async_trait::async_trait;
use crate::engines::vllm::worker;
use crate::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::runtime::CancellationToken;
pub struct VllmEngine {
    cancel_token: CancellationToken,
    worker: worker::VllmWorker,
}

impl VllmEngine {
    pub async fn new(
        cancel_token: CancellationToken,
        sock_code: &str,
        card_path: &Path,
        model_path: &Path,
    ) -> anyhow::Result<Self> {
        let w = worker::start(cancel_token.clone(), sock_code, card_path, model_path).await?;
        let engine = VllmEngine {
            cancel_token,
            worker: w,
        };
        Ok(engine)
    }

    /// Join handle for the vllm sub-process, for the caller to await at shutdown.
    pub fn take_vllm_worker_handle(&mut self) -> tokio::task::JoinHandle<()> {
        self.worker.take_vllm_handle()
    }
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
    for VllmEngine
{
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
        let (request, context) = request.into_parts();
        let ctx = context.context();
        let request_id = ctx.id().to_string();

        // Queue the request on the zmq worker; responses arrive on this channel.
        let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(128);
        let work_req = worker::WorkRequest {
            request_id: context.id().to_string(),
            request,
            response_channel: resp_tx,
        };
        self.worker.enqueue_request(work_req).await?;

        // Stream responses until the worker closes the channel or we are cancelled.
        let cancel_token = self.cancel_token.clone();
        let output = stream! {
            loop {
                let maybe_resp = tokio::select! {
                    _ = cancel_token.cancelled() => {
                        break;
                    }
                    maybe_resp = resp_rx.recv() => {
                        maybe_resp
                    }
                };
                match maybe_resp {
                    Some(out) => {
                        yield out;
                    },
                    None => {
                        tracing::trace!(request_id, "generate: response channel closed");
                        break;
                    }
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use pyo3::{types::IntoPyDict, Python};
use std::path::Path;
const PY_START_ENGINE: &std::ffi::CStr = cr#"
import multiprocessing
import signal

from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext

engine_args = AsyncEngineArgs(model=f"{model_path}", served_model_name=None, tokenizer=f"{tokenizer_path}", task='generate', tokenizer_mode='auto', seed=0, max_model_len=8192, max_seq_len_to_capture=8192)
ipc_path = f"ipc:///tmp/{socket_id}"
engine_alive = multiprocessing.Value('b', True, lock=False)
run_mp_engine(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path, engine_alive)
"#;
/// Start the Python vllm engine that listens on a zmq socket.
/// This is called by running `<bin> --internal-vllm-process`.
/// This does not return until vllm exits.
pub fn run_subprocess(
    socket_id: &str,
    model_card_path: &Path,
    model_path: &Path,
) -> anyhow::Result<()> {
    pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
    let card = model_card_path.display().to_string();
    let model_path_str = model_path.display().to_string();
    Python::with_gil(|py| {
        let locals = [
            ("socket_id", socket_id),
            ("tokenizer_path", card.as_str()),
            ("model_path", model_path_str.as_str()),
        ]
        .into_py_dict(py)
        .unwrap();
        if let Err(err) = py.run(PY_START_ENGINE, None, Some(&locals)) {
            anyhow::bail!("vllm engine run error: {err}");
        }
        tracing::info!("vllm subprocess exit");
        Ok(())
    })
}
This diff is collapsed.
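The collapsed diff is presumably the zmq worker module (`mod worker;`), which is what normally spawns this subprocess via the hidden `--internal-vllm-process` flag. For debugging, a hypothetical manual invocation, assuming only the flags defined in `Flags` above (normally `tio` relaunches itself like this):
```
# Hypothetical manual invocation, for debugging only; the parent tio
# process usually launches this itself, per the run_subprocess doc comment.
./target/release/tio --internal-vllm-process \
    --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf \
    --model-config ~/llm_models/Llama-3.2-3B-Instruct/
```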