Commit 6e0cfbd9 authored by Graham King, committed by GitHub

feat: vllm engine (#308)

Adds the vllm engine to the triton-distributed-llm component, with support for it in `tio`.
parent 37a8ebaf
......@@ -27,6 +27,7 @@ sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtne
llamacpp = ["triton-distributed-llm/llamacpp"]
cuda = ["triton-distributed-llm/cuda"]
metal = ["triton-distributed-llm/metal"]
vllm = ["triton-distributed-llm/vllm"]
[dependencies]
anyhow = "1"
......
......@@ -106,3 +106,36 @@ The extra `--model-config` flag is because:
- We don't yet read it out of the GGUF (TODO), so we need an HF repo with `tokenizer.json` et al
If the build step also builds llama_cpp libraries into `target/release` ("libllama.so", "libggml.so", "libggml-base.so", "libggml-cpu.so", "libggml-cuda.so"), then `tio` will need to find those at runtime. Set `LD_LIBRARY_PATH`, and be sure to deploy them alongside the `tio` binary.
## vllm
Uses the [vllm](https://github.com/vllm-project/vllm) Python library. We only use the back half of vllm (the engine), talking to it over `zmq`. Slow startup, fast inference. Supports both safetensors from an HF repo and GGUF files.
We use [uv](https://docs.astral.sh/uv/) but any virtualenv manager should work.
Setup:
```
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install vllm setuptools
```
**Note: If you're on Ubuntu 22.04 or earlier, you will need to add `--python=python3.10` to your `uv venv` command**
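For example:
```
uv venv --python=python3.10
```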
Build:
```
cargo build --release --features vllm
```
Run (still inside that virtualenv) - HF repo:
```
./target/release/tio in=http out=vllm --model-path ~/llm_models/Llama-3.2-3B-Instruct/
```
Run (still inside that virtualenv) - GGUF:
```
./target/release/tio in=http out=vllm --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf --model-config ~/llm_models/Llama-3.2-3B-Instruct/
```
......@@ -107,6 +107,12 @@ pub struct Flags {
#[arg(long)]
pub dist_init_addr: Option<String>,
/// Internal use only.
/// Start the Python vllm engine sub-process.
#[arg(long)]
#[clap(hide = true, default_value = "false")]
pub internal_vllm_process: bool,
/// Internal use only.
/// Start the sglang Python sub-process.
/// The params in the tuple are:
......@@ -176,25 +182,37 @@ pub async fn run(
.model_path_pos
.or(flags.model_path_flag)
.and_then(|p| p.canonicalize().ok());
// Serve the model under the name provided, or the name of the GGUF file.
// Serve the model under the name provided, or the name of the GGUF file or HF repo.
let model_name = flags.model_name.or_else(|| {
model_path
.as_ref()
.and_then(|p| p.iter().last())
.map(|n| n.to_string_lossy().into_owned())
});
// If model path is a directory we can build a model deployment card from it
let maybe_card = match &model_path {
Some(model_path) if model_path.is_dir() => {
ModelDeploymentCard::from_local_path(model_path, model_name.as_deref())
// Load the model deployment card, if any
// Only used by some engines, so without those feature flags it's unused.
#[allow(unused_variables)]
let (maybe_card_path, maybe_card) = match (&model_path, &flags.model_config) {
// --model-config takes precedence
(_, Some(model_config)) => {
let card = ModelDeploymentCard::from_local_path(model_config, model_name.as_deref())
.await
.ok();
(Some(model_config.clone()), card)
}
// If --model-path is an HF repo use that
(Some(model_path), _) if model_path.is_dir() => {
let card = ModelDeploymentCard::from_local_path(model_path, model_name.as_deref())
.await
.ok()
.ok();
(Some(model_path.clone()), card)
}
Some(_) | None => None,
// Otherwise we don't have one, but we only need it if we're tokenizing
_ => (None, None),
};
#[cfg(feature = "sglang")]
let mut extra = None; // sglang sub-process
#[cfg(any(feature = "vllm", feature = "sglang"))]
let mut extra = None; // vllm and sglang sub-process
// Create the engine matching `out`
let engine_config = match out_opt {
......@@ -304,6 +322,39 @@ pub async fn run(
card: Box::new(card),
}
}
#[cfg(feature = "vllm")]
Output::Vllm => {
use triton_distributed_llm::engines::vllm;
let Some(model_path) = model_path else {
anyhow::bail!(
"out=vllm requires flag --model-path=<full-path-to-hf-repo-or-model-gguf>"
);
};
let Some(card_path) = maybe_card_path else {
// If we have a gguf we also need a model card because we don't currently parse
// tokenizer et al out of gguf.
anyhow::bail!(
"Running GGUF files also requires a `--model-config` for the tokenizer et al."
);
};
let Some(card) = maybe_card.clone() else {
anyhow::bail!(
"out=vllm requires --model-path to be an HF repo, or for GGUF add flag --model-config <hf-repo>"
);
};
let Some(sock_prefix) = zmq_socket_prefix else {
anyhow::bail!("vllm requires zmq_socket_prefix");
};
let (engine, vllm_process) =
vllm::make_engine(cancel_token.clone(), &card_path, &model_path, &sock_prefix)
.await?;
extra = Some(vllm_process);
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
card: Box::new(card),
}
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
use anyhow::Context;
......@@ -314,25 +365,10 @@ pub async fn run(
if !model_path.is_file() {
anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
}
let card = match flags.model_config {
None => {
anyhow::bail!("Pass --model-config so we can find the tokenizer, should be an HF checkout.");
}
Some(card_path) => {
if !card_path.is_dir() {
let Some(card) = maybe_card else {
anyhow::bail!(
"--model-config should be a Hugging Face repo checkout directory."
"Pass --model-config so we can find the tokenizer, should be an HF checkout."
);
}
ModelDeploymentCard::from_local_path(&card_path, model_name.as_deref())
.await
.with_context(|| {
format!(
"Failed loading ModelDeploymentCard from {}",
card_path.display()
)
})?
}
};
let engine = llamacpp::make_engine(cancel_token.clone(), &model_path).await?;
EngineConfig::StaticCore {
......@@ -361,9 +397,9 @@ pub async fn run(
}
}
#[cfg(feature = "sglang")]
#[cfg(any(feature = "vllm", feature = "sglang"))]
// Allow engines to ask main thread to wait on an extra future.
// sglang uses this to shut down sub-process
// vllm and sglang use this to shut down sub-process
if let Some(extra) = extra {
extra.await?;
}
......
......@@ -41,7 +41,7 @@ const DEFAULT_OUT: Output = Output::EchoFull;
const ZMQ_SOCKET_PREFIX: &str = "tio";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> {
logging::init();
......@@ -83,6 +83,28 @@ fn main() -> anyhow::Result<()> {
panic!("Rebuild with --features=sglang");
}
}
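// The bindings below go unused when the vllm feature is off, hence the allow.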
#[allow(unused_variables)]
if flags.internal_vllm_process {
let Some(model_path) = flags.model_path_flag else {
anyhow::bail!("vllm subprocess requires --model-path flag");
};
let Some(model_config) = flags.model_config else {
anyhow::bail!("vllm subprocess requires --model-config");
};
if !model_config.is_dir() {
anyhow::bail!("vllm subprocess requires model config path to be a directory containing tokenizer.json, config.json, etc");
}
if cfg!(feature = "vllm") {
#[cfg(feature = "vllm")]
{
use triton_distributed_llm::engines::vllm;
return vllm::run_subprocess(ZMQ_SOCKET_PREFIX, &model_config, &model_path);
}
} else {
panic!("Rebuild with --features=vllm");
}
}
}
// max_worker_threads and max_blocking_threads from env vars or config file.
......
......@@ -83,6 +83,10 @@ pub enum Output {
#[cfg(feature = "llamacpp")]
/// Run inference using llama.cpp
LlamaCpp,
#[cfg(feature = "vllm")]
/// Run inference using vllm's engine
Vllm,
}
impl TryFrom<&str> for Output {
......@@ -99,6 +103,9 @@ impl TryFrom<&str> for Output {
#[cfg(feature = "llamacpp")]
"llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp),
#[cfg(feature = "vllm")]
"vllm" => Ok(Output::Vllm),
"echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore),
......@@ -124,6 +131,9 @@ impl fmt::Display for Output {
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => "llamacpp",
#[cfg(feature = "vllm")]
Output::Vllm => "vllm",
Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core",
......
......@@ -34,6 +34,7 @@ mistralrs = ["dep:mistralrs"]
llamacpp = ["dep:llama-cpp-2"]
sglang = ["dep:async_zmq"]
sentencepiece = ["dep:sentencepiece"]
vllm = ["dep:async_zmq"]
cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"]
metal = ["mistralrs/metal", "llama-cpp-2/metal"]
......
......@@ -21,3 +21,6 @@ pub mod sglang;
#[cfg(feature = "llamacpp")]
pub mod llamacpp;
#[cfg(feature = "vllm")]
pub mod vllm;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use std::sync::Arc;
use triton_distributed_runtime::pipeline::error as pipeline_error;
use triton_distributed_runtime::CancellationToken;
use crate::backend::ExecutionContext;
mod worker;
mod engine;
use engine::VllmEngine;
mod subprocess;
pub use subprocess::run_subprocess;
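/// Create the vllm engine: starts the Python vllm sub-process and connects to it over zmq.
/// The returned JoinHandle is for the task watching that sub-process; the caller should await
/// it on shutdown so the sub-process is reaped.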
pub async fn make_engine(
cancel_token: CancellationToken,
// Where to find the tokenizer and config.json
card_path: &Path,
// Full path to the model, either a GGUF file or an HF repo dir
model_path: &Path,
// Unique string to name zmq sockets
sock_code: &str,
) -> pipeline_error::Result<(ExecutionContext, tokio::task::JoinHandle<()>)> {
let mut engine = VllmEngine::new(cancel_token, sock_code, card_path, model_path).await?;
let vllm_process = engine.take_vllm_worker_handle();
let engine: ExecutionContext = Arc::new(engine);
Ok((engine, vllm_process))
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use async_stream::stream;
use async_trait::async_trait;
use crate::engines::vllm::worker;
use crate::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::runtime::CancellationToken;
pub struct VllmEngine {
cancel_token: CancellationToken,
worker: worker::VllmWorker,
}
impl VllmEngine {
pub async fn new(
cancel_token: CancellationToken,
sock_code: &str,
card_path: &Path,
model_path: &Path,
) -> anyhow::Result<Self> {
let w = worker::start(cancel_token.clone(), sock_code, card_path, model_path).await?;
let engine = VllmEngine {
cancel_token,
worker: w,
};
Ok(engine)
}
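/// Take the JoinHandle of the task watching the vllm sub-process. Can only be called once.
/// The caller awaits it so the sub-process does not go zombie on shutdown.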
pub fn take_vllm_worker_handle(&mut self) -> tokio::task::JoinHandle<()> {
self.worker.take_vllm_handle()
}
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for VllmEngine
{
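/// Forward the request to the worker and stream vllm's responses back until the response
/// channel closes or we are cancelled.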
async fn generate(
&self,
request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let (request, context) = request.into_parts();
let ctx = context.context();
let request_id = ctx.id().to_string();
let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(128);
let work_req = worker::WorkRequest {
request_id: context.id().to_string(),
request,
response_channel: resp_tx,
};
self.worker.enqueue_request(work_req).await?;
let cancel_token = self.cancel_token.clone();
let output = stream! {
loop {
let maybe_resp = tokio::select!{
_ = cancel_token.cancelled() => {
break;
}
maybe_resp = resp_rx.recv() => {
maybe_resp
}
};
match maybe_resp {
Some(out) => {
yield out;
},
None => {
tracing::trace!(request_id, "generate: response channel closed");
break;
}
}
}
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use pyo3::{types::IntoPyDict, Python};
use std::path::Path;
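// Python snippet run in the sub-process: builds AsyncEngineArgs for the model and starts
// vllm's multiprocessing engine listening on the ipc socket. `model_path`, `tokenizer_path`
// and `socket_id` are supplied via the `locals` dict in `run_subprocess` below.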
const PY_START_ENGINE: &std::ffi::CStr = cr#"
import multiprocessing
import signal
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext
engine_args = AsyncEngineArgs(model=f"{model_path}", served_model_name=None, tokenizer=f"{tokenizer_path}", task='generate', tokenizer_mode='auto', seed=0, max_model_len=8192, max_seq_len_to_capture=8192)
ipc_path = f"ipc:///tmp/{socket_id}";
engine_alive = multiprocessing.Value('b', True, lock=False)
run_mp_engine(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path, engine_alive)
"#;
/// Start the Python vllm engine that listens on zmq socket
/// This is called by running `<bin> --internal-vllm-process`.
/// This does not return until vllm exits.
pub fn run_subprocess(
socket_id: &str,
model_card_path: &Path,
model_path: &Path,
) -> anyhow::Result<()> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
let card = model_card_path.display().to_string();
let model_path_str = model_path.display().to_string();
Python::with_gil(|py| {
let locals = [
("socket_id", socket_id),
("tokenizer_path", card.as_str()),
("model_path", model_path_str.as_str()),
]
.into_py_dict(py)
.unwrap();
if let Err(err) = py.run(PY_START_ENGINE, None, Some(&locals)) {
anyhow::bail!("vllm engine run error: {err}");
}
tracing::info!("vllm subprocess exit");
Ok(())
})
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::HashMap, ops::Deref, path::Path, process::Stdio, sync::Arc, time::Duration,
vec::IntoIter,
};
use async_zmq::{SinkExt, StreamExt};
use pyo3::{
prelude::*,
types::{IntoPyDict, PyBytes, PyString},
};
use tokio::sync::mpsc::Sender;
use tokio::task::JoinHandle;
use tokio::{io::AsyncBufReadExt, sync::mpsc::error::SendError};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::CancellationToken;
use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason;
/// If the user does not provide max_tokens, limit output to this many tokens
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Wait this long for the vllm sub-process to stop after we send it a KILL
const VLLM_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
type RequestID = String;
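/// Talks to the vllm Python sub-process over zmq. Requests are pickled and pushed on the
/// input socket; responses are pulled from the output socket and routed back to the right
/// caller by request id.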
pub struct VllmWorker {
/// How we receive work requests
tx: Sender<WorkRequest>,
/// Handle of the task that reads from `tx` and forwards those requests over zmq to vllm
_input_loop: JoinHandle<()>,
/// Handle of the task that reads vllm's responses from zmq and dispatches them to the correct
/// active request.
_output_loop: JoinHandle<()>,
/// Handle of the vllm background process
vllm: Option<JoinHandle<()>>,
// We don't need to hold on to this, it's already shared between input_loop and output_loop
// But later we'll probably want stats - how many active requests etc, so keep it here
_active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
// Need to keep this alive
// TODO: With async_zmq we possibly don't need this at all
#[allow(dead_code)]
zmq_context: async_zmq::Context,
}
/// How we get asked to do some work. These get unpacked and forwarded to vllm.
pub struct WorkRequest {
pub request: PreprocessedRequest,
pub request_id: RequestID,
pub response_channel: Sender<Annotated<LLMEngineOutput>>,
}
/// A request currently being processed by vllm
struct ActiveRequest {
tx: Sender<Annotated<LLMEngineOutput>>,
num_output_tokens_so_far: usize,
max_tokens: usize,
}
/// Python imports
struct Imports {
pickle_module: PyObject,
tokens_prompt_type: PyObject,
sample_params_type: PyObject,
rpc_type: PyObject,
startup_type: PyObject,
}
/// All the zmq sockets we use. This struct only exists to pass them around without a large
/// tuple.
struct Sockets {
#[allow(dead_code)]
context: async_zmq::Context, // we have to keep this alive
// Control socket, how we ask vllm engine to start.
// Not the best name, but this is what vllm calls it internally.
data: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
// Requests from us to the vllm engine
input: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
// Responses from the vllm engine back to us
output: async_zmq::Pull,
// Heartbeat messages from vllm process
heartbeat: async_zmq::Pull,
}
/// The message vllm sends us over zmq when it's ready to work.
#[derive(FromPyObject, Debug)]
struct RPCStartupResponse {
#[allow(dead_code)]
tracing_enabled: bool,
}
/// What vllm sends us. Usually it contains a single token.
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
pub struct RequestOutput {
request_id: String,
prompt: Option<String>,
prompt_token_ids: Option<Vec<u32>>,
prompt_logprobs: Option<Vec<Option<HashMap<u32, Logprob>>>>,
outputs: Vec<CompletionOutput>,
finished: bool,
//metrics: Optional[RequestMetrics] = None,
//lora_request: Optional[LoRARequest] = None,
encoder_prompt: Option<String>,
encoder_prompt_token_ids: Option<Vec<u32>>,
num_cached_tokens: Option<u32>,
}
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
pub struct CompletionOutput {
index: u32,
text: String,
token_ids: Vec<u32>,
cumulative_logprob: Option<f32>,
logprobs: Option<Vec<HashMap<u32, Logprob>>>,
finish_reason: Option<String>,
//stop_reason: Union[int, str, None] = None
//lora_request: Optional[LoRARequest] = None
}
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
struct Logprob {
logprob: f32,
rank: Option<u32>,
decoded_token: Option<String>,
}
/// Main entry point
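/// Starts the vllm Python sub-process, waits for its zmq startup handshake, then spawns the
/// heartbeat, input and output loops.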
pub async fn start(
cancel_token: CancellationToken,
sock_code: &str,
card_path: &Path,
model_path: &Path,
) -> anyhow::Result<VllmWorker> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
let py_imports = Arc::new(python_imports());
let Sockets {
context,
data,
input,
output,
heartbeat,
} = zmq_sockets(sock_code)?;
let vllm_process = start_vllm(card_path, model_path, &py_imports, data).await?;
let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process);
tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat));
let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let (tx, rx) = tokio::sync::mpsc::channel(8);
let input_loop_handle = {
let cancel_token = cancel_token.clone();
let py_imports = py_imports.clone();
let active_requests = active_requests.clone();
tokio::spawn(input_loop(
cancel_token,
py_imports,
input,
active_requests,
rx,
))
};
let output_loop_handle = {
let cancel_token = cancel_token.clone();
let py_imports = py_imports.clone();
let active_requests = active_requests.clone();
tokio::spawn(output_loop(
cancel_token,
py_imports,
output,
active_requests,
))
};
Ok(VllmWorker {
tx,
zmq_context: context,
_input_loop: input_loop_handle,
_output_loop: output_loop_handle,
vllm: Some(vllm_join_handle),
_active_requests: active_requests,
})
}
/// Import all the python packages we'll need. `vllm` particularly takes a few seconds.
fn python_imports() -> Imports {
Python::with_gil(|py| {
let pickle_module: PyObject = match py.import("pickle") {
Ok(m) => m.into(),
Err(err) => {
// There is no vllm without python
panic!("Failed to import python 'pickle' module. Is Python installed? {err}");
}
};
let vllm_module: PyObject = match py.import("vllm") {
Ok(m) => m.into(),
Err(err) => {
panic!("Failed to import python 'vllm' module. Are we running in the correct venv? {err}");
}
};
let tokens_prompt_type: PyObject = vllm_module.getattr(py, "TokensPrompt").unwrap();
let sample_params_type: PyObject = vllm_module.getattr(py, "SamplingParams").unwrap();
let mod_multiprocessing = py.import("vllm.engine.multiprocessing").unwrap();
let rpc_type: PyObject = mod_multiprocessing
.getattr("RPCProcessRequest")
.unwrap()
.into();
let startup_type: PyObject = mod_multiprocessing
.getattr("RPCStartupRequest")
.unwrap()
.into();
Imports {
pickle_module,
tokens_prompt_type,
sample_params_type,
rpc_type,
startup_type,
}
})
}
/// Create all the zmq sockets we're going to use.
fn zmq_sockets(sock_code: &str) -> anyhow::Result<Sockets> {
let zmq_context = async_zmq::Context::new();
let input = async_zmq::push(&format!("ipc:///tmp/{sock_code}_input_socket"))?
.with_context(&zmq_context)
.connect()?;
let output = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_output_socket"))?
.with_context(&zmq_context)
.connect()?;
let data = async_zmq::dealer(&format!("ipc:///tmp/{sock_code}_data_socket"))?
.with_context(&zmq_context)
.connect()?;
let heartbeat = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_health_socket"))?
.with_context(&zmq_context)
.connect()?;
Ok(Sockets {
context: zmq_context,
data,
input,
output,
heartbeat,
})
}
/// Start the vllm python sub-process and wait for it to start
async fn start_vllm(
card_path: &Path,
model_path: &Path,
python_imports: &Imports,
mut data_socket: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
) -> anyhow::Result<tokio::process::Child> {
// The in/out args are not used but we currently require them for parsing cli args
let vllm_args = [
"--internal-vllm-process",
&format!("--model-config={}", card_path.display()),
&format!("--model-path={}", model_path.display()),
];
let self_path = std::env::current_exe()?;
let mut proc = tokio::process::Command::new(self_path)
.env("VLLM_LOGGING_LEVEL", "DEBUG")
.args(vllm_args)
.kill_on_drop(false)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let stdout = tokio::io::BufReader::new(proc.stdout.take().unwrap());
let stderr = tokio::io::BufReader::new(proc.stderr.take().unwrap());
tokio::spawn(async move {
let mut lines = stdout.lines();
while let Ok(Some(line)) = lines.next_line().await {
let mut line_parts = line.splitn(4, ' ');
let log_level = line_parts.next().unwrap_or_default();
// Skip date (0) and time (1). Print last (2) which is everything else.
let line = line_parts.nth(2).unwrap_or_default();
if line.starts_with("custom_op.py:68") {
// Skip a noisy line
// custom_op.py:68] custom op <the op> enabled
continue;
}
match log_level {
"DEBUG" => tracing::debug!("VLLM: {line}"),
"INFO" => tracing::info!("VLLM: {line}"),
"WARNING" => tracing::warn!("VLLM: {line}"),
level => tracing::info!("VLLM: {level} {line}"),
}
}
});
tokio::spawn(async move {
let mut lines = stderr.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::warn!("VLLM: {line}");
}
});
let start_req_bytes: Vec<u8> = Python::with_gil(|py| {
let start_req = python_imports
.startup_type
.getattr(py, "IS_SERVER_READY")
.unwrap();
let pickle_dumps = python_imports.pickle_module.getattr(py, "dumps").unwrap();
pickle_dumps
.call1(py, (start_req,))
.unwrap()
.extract(py)
.unwrap()
});
data_socket.send(vec![start_req_bytes].into()).await?;
let start_resp: Vec<u8> = match data_socket.next().await {
Some(Ok(r)) => {
if !r.is_empty() {
r[0].deref().to_vec()
} else {
anyhow::bail!("vllm failed to start. No response on dealer/data socket");
}
}
Some(Err(err)) => {
anyhow::bail!("vllm failed to start. Error reading from dealer/data socket: {err}");
}
None => {
anyhow::bail!("vllm failed to start. dealer/data socket is closed.");
}
};
let resp: RPCStartupResponse = Python::with_gil(|py| {
let pickle_loads = python_imports.pickle_module.getattr(py, "loads").unwrap();
pickle_loads
.call1(py, (start_resp,))
.unwrap()
.extract(py)
.unwrap()
});
tracing::info!("vllm zmq backend is ready: {resp:?}");
Ok(proc)
}
// Stop the vllm process when we stop, and prevent it from going zombie.
fn watch_vllm(
cancel_token: CancellationToken,
mut vllm_process: tokio::process::Child,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
cancel_token.cancelled().await;
tokio::select! {
_ = vllm_process.wait() => {
return;
},
_ = tokio::time::sleep(VLLM_STOP_TIMEOUT) => { }
}
if let Err(err) = vllm_process.start_kill() {
tracing::error!("Failed to kill vllm subprocess: {err}");
return;
}
tokio::select! {
_ = vllm_process.wait() => { },
_ = tokio::time::sleep(VLLM_STOP_TIMEOUT) => {
tracing::warn!("Timeout waiting for vllm sub-process to stop after kill");
}
}
})
}
// How we know the vllm engine is alive. It sends "SUCCESS" as a pickled string every 10s.
// Spawned as a tokio task from `start`.
// TODO: If we don't get heartbeats we should, euh, do something. vllm is gone. At least
// de-register the model.
async fn heartbeat_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pull) {
loop {
let maybe_hb = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
maybe_hb = socket.next() => {
maybe_hb
}
};
let b = match maybe_hb {
Some(Ok(b)) => b[0].deref().to_vec(),
Some(Err(err)) => {
tracing::error!("Error reading from vllm heartbeat socket: {err}");
break;
}
None => {
tracing::debug!("vllm heartbeat socket closed");
break;
}
};
let s: String = match serde_pickle::from_slice(&b, Default::default()) {
Ok(s) => s,
Err(err) => {
tracing::error!("Error de-serializing vllm heartbeat response. It was probably Exception not str. {err}");
break;
}
};
if s != "SUCCESS" {
tracing::error!("vllm heartbeat error, expected 'SUCCESS' got '{s}'");
break;
}
}
}
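/// Convert one vllm `CompletionOutput` into an `LLMEngineOutput`. vllm reports the cumulative
/// token list for a request, so only the tokens past `previous_total_toks` are kept.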
fn from_vllm(output: CompletionOutput, previous_total_toks: usize) -> LLMEngineOutput {
let finish_reason = match output.finish_reason.as_deref() {
Some("stop") => Some(FinishReason::Stop),
Some("abort") => Some(FinishReason::Cancelled),
Some("length") => Some(FinishReason::Length),
Some(unknown) => {
tracing::info!("Unknown vllm stop reason '{unknown}'. Please add to vllm.rs");
Some(FinishReason::Stop)
}
None => None,
};
LLMEngineOutput {
// todo - propagate mdcsum
token_ids: output.token_ids[previous_total_toks..].into(),
tokens: None,
text: None,
//text: if output.text.is_empty() { None } else { Some(output.text) },
cum_log_probs: output.cumulative_logprob.map(|v| v as f64),
log_probs: None, // TODO output.logprobs
finish_reason,
}
}
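/// Receive `WorkRequest`s from the engine, pickle each one into a vllm `RPCProcessRequest`,
/// register it in `active_requests`, and push it to vllm over zmq.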
async fn input_loop(
cancel_token: CancellationToken,
py_imports: Arc<Imports>,
mut input_socket: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
mut rx: tokio::sync::mpsc::Receiver<WorkRequest>,
) {
loop {
let work_request = tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("VllmWorker.input_loop exit");
break;
}
req = rx.recv() => {
match req {
Some(req) => req,
None => {
tracing::trace!("VllmWorker input_loop socket closed");
break;
}
}
}
};
let request_id = work_request.request_id;
let token_ids = work_request.request.token_ids.clone();
let temperature: f64 = work_request
.request
.sampling_options
.temperature
.unwrap_or(0.0)
.into();
let max_tokens = work_request
.request
.stop_conditions
.max_tokens
.unwrap_or(DEFAULT_MAX_TOKENS);
// Parts that don't change
let (py_request_id, sampling_params) = Python::with_gil(|py| {
let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
let sp_kwargs = [("temperature", py_temp), ("max_tokens", py_max_tokens)]
.into_py_dict(py)
.unwrap();
let sampling_params = py_imports
.sample_params_type
.call(py, (), Some(&sp_kwargs))
.unwrap();
let py_request_id: PyObject = PyString::new(py, &request_id).into();
(py_request_id, sampling_params)
});
let pickled_req: Vec<u8> = Python::with_gil(|py| {
let token_prompt_kwargs = [("prompt_token_ids", token_ids.clone())]
.into_py_dict(py)
.unwrap();
let prompt_obj = py_imports
.tokens_prompt_type
.call(py, (), Some(&token_prompt_kwargs))
.unwrap();
let rpc_kwargs = [
("prompt", prompt_obj),
("params", sampling_params.clone()),
("request_id", py_request_id.clone()),
]
.into_py_dict(py)
.unwrap();
let req = py_imports.rpc_type.call(py, (), Some(&rpc_kwargs)).unwrap();
let pickle_dumps = py_imports.pickle_module.getattr(py, "dumps").unwrap();
pickle_dumps.call1(py, (req,)).unwrap().extract(py).unwrap()
});
let new_active_request = ActiveRequest {
tx: work_request.response_channel,
max_tokens: max_tokens as usize,
num_output_tokens_so_far: 0,
};
active_requests
.lock()
.await
.insert(request_id, new_active_request);
if let Err(err) = input_socket.send(vec![pickled_req].into()).await {
tracing::error!("Error sending new request to vllm over zmq: {err}");
}
}
}
/// Read from vllm's output zmq socket, find which request it is for and forward over that channel.
async fn output_loop(
cancel_token: CancellationToken,
py_imports: Arc<Imports>,
mut output_socket: async_zmq::Pull,
active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
) {
loop {
let mut bb = tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("VllmWorker.output_loop exit");
break;
}
from_vllm = output_socket.next() => {
match from_vllm {
Some(Ok(b)) => b,
Some(Err(err)) => {
tracing::error!("Error reading from vllm zmq output: {err}");
continue; // hope lives eternal
}
None => {
tracing::debug!("zmq output socket closed");
break;
}
}
}
};
let frame = bb.remove(0);
let mut reqs_out: Vec<RequestOutput> = Python::with_gil(|py| {
let pickle_loads = py_imports.pickle_module.getattr(py, "loads").unwrap();
let frame_bytes = PyBytes::new(py, &frame);
pickle_loads
.call1(py, (frame_bytes,))
.unwrap()
.extract(py)
.unwrap()
});
if reqs_out.is_empty() {
tracing::debug!("Received message from vllm with no content");
continue;
}
let req_out = reqs_out.remove(0);
if req_out.finished {
// The last token is the eos_token, don't forward it
let out = Annotated::from_data(LLMEngineOutput::stop());
let maybe_active = active_requests.lock().await.remove(&req_out.request_id);
match maybe_active {
Some(active) => {
let _ = active.tx.send(out).await;
}
None => {
tracing::warn!(
req_out.request_id,
"Missing active request to notify of stop"
);
}
}
continue;
}
let mut remove_after = false;
for vllm_output in req_out.outputs.into_iter() {
let next_total_toks = vllm_output.token_ids.len();
match active_requests.lock().await.get_mut(&req_out.request_id) {
Some(active) => {
let out = from_vllm(vllm_output, active.num_output_tokens_so_far);
active.num_output_tokens_so_far = next_total_toks;
let out = if active.num_output_tokens_so_far <= active.max_tokens {
Annotated::from_data(out)
} else {
// we exceeded max tokens, this request is over
remove_after = true;
Annotated::from_data(LLMEngineOutput::length())
};
let _ = active.tx.send(out).await;
}
None => {
tracing::warn!(req_out.request_id, "Missing active request");
}
}
}
if remove_after {
let _ = active_requests.lock().await.remove(&req_out.request_id);
}
}
}
impl VllmWorker {
/// Send a request to vllm
pub async fn enqueue_request(&self, r: WorkRequest) -> Result<(), SendError<WorkRequest>> {
self.tx.send(r).await
}
/// Get the vllm sub-process handle, so we can await it and prevent it going zombie.
pub fn take_vllm_handle(&mut self) -> JoinHandle<()> {
self.vllm.take().unwrap()
}
}