Commit a657ec61 authored by Graham King, committed by GitHub

feat: vllm engine tensor parallel and pipeline parallel (#16)

Needs more testing but good enough for now. I get the same results with this as with `vllm serve`.
parent a32cdad6
......@@ -24,11 +24,11 @@ license = "Apache-2.0"
[features]
mistralrs = ["triton-distributed-llm/mistralrs"]
sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtnetlink"]
vllm = ["triton-distributed-llm/vllm", "dep:netlink-packet-route", "dep:rtnetlink"]
llamacpp = ["triton-distributed-llm/llamacpp"]
trtllm = ["triton-distributed-llm/trtllm"]
cuda = ["triton-distributed-llm/cuda"]
metal = ["triton-distributed-llm/metal"]
vllm = ["triton-distributed-llm/vllm"]
[dependencies]
anyhow = "1"
......
......@@ -118,7 +118,7 @@ Setup:
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install vllm setuptools
uv pip install vllm==0.7.3 setuptools
```
**Note: If you're on Ubuntu 22.04 or earlier, you will need to add `--python=python3.10` to your `uv venv` command**
......@@ -139,6 +139,19 @@ Run (still inside that virtualenv) - GGUF:
./target/release/tio in=http out=vllm --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf --model-config ~/llm_models/Llama-3.2-3B-Instruct/
```
Multi-node:
Node 1:
```
tio in=text out=vllm ~/llm_models/Llama-3.2-3B-Instruct/ --tensor-parallel-size 8 --num-nodes 2 --leader-addr 10.217.98.122:6539 --node-rank 0
```
Node 2:
```
tio in=none out=vllm ~/llm_models/Llama-3.2-3B-Instruct/ --num-nodes 2 --leader-addr 10.217.98.122:6539 --node-rank 1
```
## trtllm
TensorRT-LLM. Requires `clang` and `libclang-dev`.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::PathBuf;
use std::str::FromStr;
/// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
pub struct Flags {
/// Full path to the model, which can be either a GGUF file or a checked out HF repository.
/// For the `echo_full` engine omit the flag.
#[arg(index = 1)]
pub model_path_pos: Option<PathBuf>,
// `--model-path`. The one above is `tio <positional-model-path>`
#[arg(long = "model-path")]
pub model_path_flag: Option<PathBuf>,
/// HTTP port. `in=http` only
#[arg(long, default_value = "8080")]
pub http_port: u16,
/// The name of the model we are serving
#[arg(long)]
pub model_name: Option<String>,
/// llamacpp only
///
/// The path to the tokenizer and model config, needed because:
/// - llama_cpp only runs GGUF files
/// - our engine is a 'core' engine in that we do the tokenization, so we need the vocab
/// - TODO: we don't yet extract that from the GGUF. Once we do we can remove this flag.
#[arg(long)]
pub model_config: Option<PathBuf>,
/// sglang, vllm, trtllm
///
/// How many GPUs to use at once, total across all nodes.
/// This must be divisible by num_nodes, and each node must use the same number of GPUs.
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub tensor_parallel_size: u32,
/// sglang only
/// (vllm uses the CUDA_VISIBLE_DEVICES env var instead)
///
/// Use GPUs from this ID upwards.
/// If your machine has four GPUs but the first two (0 and 1) are in use,
/// pass --base-gpu-id 2 to use the third GPU (and up, if tensor_parallel_size > 1)
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..256))]
pub base_gpu_id: u32,
/// vllm and sglang only
///
/// How many nodes/hosts to use
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub num_nodes: u32,
/// vllm and sglang only
///
/// This node's unique ID, running from 0 to num_nodes - 1.
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..255))]
pub node_rank: u32,
/// For multi-node / pipeline parallel this is the <host>:<port> of the first node.
///
/// - vllm: The address/port of the Ray head node.
///
/// - sglang: The Torch Distributed init method address, in format <host>:<port>.
/// It becomes "tcp://<host>:<port>" when given to torch.distributed.init_process_group.
/// This expects to use the nccl backend (transparently to us here).
/// All nodes must use the same address here, which is node_rank == 0's address.
///
#[arg(long)]
pub leader_addr: Option<String>,
/// Internal use only.
// Start the python vllm engine sub-process.
#[arg(long)]
#[clap(hide = true, default_value = "false")]
pub internal_vllm_process: bool,
/// Internal use only.
/// Start the sglang Python sub-process.
/// The params in the tuple are:
/// - the fd of the write end of a pipe where sglang will signal that it's ready.
/// - the node rank (0 for first host, 1 for second host, etc)
/// - the worker's rank (globally unique)
/// - the GPU to use (locally unique)
#[arg(long)]
#[clap(hide = true, value_parser = parse_sglang_flags)]
pub internal_sglang_process: Option<SgLangFlags>,
}
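For orientation, a minimal sketch (not part of this change) of how the README's node-0 command maps onto `Flags`; the model path is a placeholder and the `in=`/`out=` selectors are omitted here:

```rust
fn example() {
    use clap::Parser as _;
    // Hypothetical invocation mirroring the README's node-0 multi-node command.
    let flags = Flags::parse_from([
        "tio",
        "/models/Llama-3.2-3B-Instruct",
        "--tensor-parallel-size", "8",
        "--num-nodes", "2",
        "--node-rank", "0",
        "--leader-addr", "10.217.98.122:6539",
    ]);
    assert_eq!(flags.tensor_parallel_size, 8);
    assert_eq!(flags.num_nodes, 2);
    assert_eq!(flags.leader_addr.as_deref(), Some("10.217.98.122:6539"));
}
```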
#[derive(Debug, Clone, Copy)]
pub struct SgLangFlags {
pub pipe_fd: u32,
pub tp_rank: u32,
pub gpu_id: u32,
}
fn parse_sglang_flags(s: &str) -> Result<SgLangFlags, String> {
let nums: Vec<u32> = s
.split(',')
.map(u32::from_str)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?;
if nums.len() != 3 {
return Err("Need exactly 3 numbers".into());
}
Ok(SgLangFlags {
pipe_fd: nums[0],
tp_rank: nums[1],
gpu_id: nums[2],
})
}
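A small sketch of the hidden flag's value format, using hypothetical numbers:

```rust
fn example() {
    // "--internal-sglang-process" takes "<pipe_fd>,<tp_rank>,<gpu_id>".
    let flags = parse_sglang_flags("7,0,1").expect("three comma-separated integers");
    assert_eq!((flags.pipe_fd, flags.tp_rank, flags.gpu_id), (7, 0, 1));
    // Anything other than exactly three numbers is rejected.
    assert!(parse_sglang_flags("7,0").is_err());
}
```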
......@@ -78,6 +78,7 @@ pub async fn run(
EngineConfig::Dynamic(_) => {
anyhow::bail!("Cannot use endpoint for both in and out");
}
EngineConfig::None => unreachable!(),
};
let model_registration = ModelEntry {
......
......@@ -96,6 +96,7 @@ pub async fn run(
.model_manager()
.add_chat_completions_model(&service_name, pipeline)?;
}
EngineConfig::None => unreachable!(),
}
http_service.run(runtime.primary_token()).await
}
......@@ -91,6 +91,7 @@ pub async fn run(
tracing::info!("Model: {service_name} with pre-processing");
(service_name, pipeline, true)
}
EngineConfig::None => unreachable!(),
};
main_loop(cancel_token, &service_name, engine, inspect_template).await
}
......
......@@ -13,8 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::PathBuf;
use std::str::FromStr;
#[cfg(any(feature = "vllm", feature = "sglang"))]
use std::{future::Future, pin::Pin};
use triton_distributed_llm::{
backend::ExecutionContext,
......@@ -29,8 +29,10 @@ use triton_distributed_llm::{
};
use triton_distributed_runtime::{component::Client, protocols::Endpoint, DistributedRuntime};
mod flags;
pub use flags::Flags;
mod input;
#[cfg(feature = "sglang")]
#[cfg(any(feature = "vllm", feature = "sglang"))]
mod net;
mod opt;
mod output;
......@@ -41,90 +43,6 @@ pub use opt::{Input, Output};
/// concatenations.
const ENDPOINT_SCHEME: &str = "tdr://";
/// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
pub struct Flags {
/// Full path to the model, which can be either a GGUF file or a checked out HF repository.
/// For the `echo_full` engine omit the flag.
#[arg(index = 1)]
pub model_path_pos: Option<PathBuf>,
// `--model-path`. The one above is `tio <positional-model-path>`
#[arg(long = "model-path")]
pub model_path_flag: Option<PathBuf>,
/// HTTP port. `in=http` only
#[arg(long, default_value = "8080")]
pub http_port: u16,
/// The name of the model we are serving
#[arg(long)]
pub model_name: Option<String>,
/// llamacpp only
///
/// The path to the tokenizer and model config because:
/// - llama_cpp only runs GGUF files
/// - our engine is a 'core' engine in that we do the tokenization, so we need the vocab
/// - TODO: we don't yet extract that from the GGUF. Once we do we can remove this flag.
#[arg(long)]
pub model_config: Option<PathBuf>,
/// sglang and trtllm only
///
/// How many GPUs to use at once, total across all nodes.
/// This must divide by num_nodes, and each node must use the same number of GPUs.
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub tensor_parallel_size: u32,
/// sglang only
///
/// Use GPUs from this ID upwards.
/// If your machine has four GPUs but the first two (0 and 1) are in use,
/// pass --base-gpu-id 2 to use the third GPU (and up, if tensor_parallel_size > 1)
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..256))]
pub base_gpu_id: u32,
/// sglang only
///
/// How many nodes/hosts to use
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub num_nodes: u32,
/// sglang only
///
/// This node's unique ID, running from 0 to num_nodes - 1.
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..255))]
pub node_rank: u32,
/// sglang only
///
/// The Torch Distributed init method address, in format <host>:<port>.
/// It becomes "tcp://<host>:<port>" when given to torch.distributed.init_process_group.
/// This expects to use the nccl backend (transparently to us here).
/// All nodes must use the same dist_init_addr, which is node_rank == 0's address.
#[arg(long)]
pub dist_init_addr: Option<String>,
/// Internal use only.
// Start the python vllm engine sub-process.
#[arg(long)]
#[clap(hide = true, default_value = "false")]
pub internal_vllm_process: bool,
/// Internal use only.
/// Start the sglang Python sub-process.
/// The params in the tuple are:
/// - the fd of the write end of a pipe where sglang will signal that it's ready.
/// - the node rank (0 for first host, 1 for second host, etc)
/// - the worker's rank (globally unique)
/// - the GPU to use (locally unique)
#[arg(long)]
#[clap(hide = true, value_parser = parse_sglang_flags)]
pub internal_sglang_process: Option<SgLangFlags>,
}
pub enum EngineConfig {
/// A remote networked engine we don't know about yet
/// We don't have the pre-processor yet so this is only text requests. Type will change later.
......@@ -142,35 +60,15 @@ pub enum EngineConfig {
engine: ExecutionContext,
card: Box<ModelDeploymentCard>,
},
}
#[derive(Debug, Clone, Copy)]
pub struct SgLangFlags {
pub pipe_fd: u32,
pub tp_rank: u32,
pub gpu_id: u32,
}
fn parse_sglang_flags(s: &str) -> Result<SgLangFlags, String> {
let nums: Vec<u32> = s
.split(',')
.map(u32::from_str)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?;
if nums.len() != 3 {
return Err("Need exactly 3 numbers".into());
}
Ok(SgLangFlags {
pipe_fd: nums[0],
tp_rank: nums[1],
gpu_id: nums[2],
})
/// vllm multi-node doesn't run an engine on nodes other than 0. 'ray' does all the work.
None,
}
#[allow(unused_mut)]
pub async fn run(
runtime: triton_distributed_runtime::Runtime,
in_opt: Input,
mut in_opt: Input, // mut because vllm and sglang multi-node can change it
out_opt: Output,
flags: Flags,
#[allow(unused_variables)] zmq_socket_prefix: Option<String>,
......@@ -212,7 +110,7 @@ pub async fn run(
};
#[cfg(any(feature = "vllm", feature = "sglang"))]
let mut extra = None; // vllm and sglang sub-process
let mut extra: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None; // vllm and sglang sub-process
// Create the engine matching `out`
let engine_config = match out_opt {
......@@ -293,10 +191,10 @@ pub async fn run(
let Some(sock_prefix) = zmq_socket_prefix else {
anyhow::bail!("sglang requires zmq_socket_prefix");
};
let node_conf = sglang::MultiNodeConfig {
let node_conf = triton_distributed_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
dist_init_addr: flags.dist_init_addr,
leader_addr: flags.leader_addr.unwrap_or_default(),
};
if node_conf.num_nodes > 1 {
if let Ok(Some(if_name)) = net::get_primary_interface().await {
......@@ -304,6 +202,11 @@ pub async fn run(
tracing::info!("export GLOO_SOCKET_IFNAME={if_name}");
tracing::info!("export NCCL_SOCKET_IFNAME={if_name}");
}
if node_conf.node_rank != 0 {
// Follower nodes take their input from the leader node over torch.distributed,
// not from the user.
in_opt = Input::None;
}
}
let (engine, sglang_process) = sglang::make_engine(
......@@ -315,7 +218,9 @@ pub async fn run(
flags.base_gpu_id,
)
.await?;
extra = Some(sglang_process);
extra = Some(Box::pin(async move {
let _ = sglang_process.await;
}));
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
......@@ -325,6 +230,9 @@ pub async fn run(
#[cfg(feature = "vllm")]
Output::Vllm => {
use triton_distributed_llm::engines::vllm;
if flags.base_gpu_id != 0 {
anyhow::bail!("vllm does not support base_gpu_id. Set environment variable CUDA_VISIBLE_DEVICES instead.");
}
let Some(model_path) = model_path else {
anyhow::bail!(
"out=vllm requires flag --model-path=<full-path-to-hf-repo-or-model-gguf>"
......@@ -345,19 +253,49 @@ pub async fn run(
let Some(sock_prefix) = zmq_socket_prefix else {
anyhow::bail!("vllm requires zmq_socket_prefix");
};
let (engine, vllm_process) =
vllm::make_engine(cancel_token.clone(), &card_path, &model_path, &sock_prefix)
.await?;
extra = Some(vllm_process);
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
card: Box::new(card),
let node_conf = triton_distributed_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
leader_addr: flags.leader_addr.unwrap_or_default(),
};
if node_conf.num_nodes > 1 {
if let Ok(Some(if_name)) = net::get_primary_interface().await {
tracing::info!("If you see network errors from vllm try setting this environment variable:");
tracing::info!("export NCCL_SOCKET_IFNAME={if_name}");
}
if node_conf.node_rank != 0 {
// Only node 0 runs vllm; the other nodes communicate over Ray
in_opt = Input::None;
}
}
if node_conf.node_rank == 0 {
// In multi-node vllm, only the leader node runs the engine
let (engine, vllm_future) = vllm::make_leader_engine(
cancel_token.clone(),
&card_path,
&model_path,
&sock_prefix,
node_conf,
flags.tensor_parallel_size,
)
.await?;
extra = Some(Box::pin(async move {
let _ = vllm_future.await;
}));
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
card: Box::new(card),
}
} else {
// Nodes with rank > 0 only run 'ray'
let stop_future = vllm::start_follower(cancel_token.clone(), node_conf).await?;
extra = Some(Box::pin(stop_future));
EngineConfig::None
}
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
use anyhow::Context;
use triton_distributed_llm::engines::llamacpp;
let Some(model_path) = model_path else {
anyhow::bail!("out=llamacpp requires flag --model-path=<full-path-to-model-gguf>");
......@@ -419,9 +357,8 @@ pub async fn run(
#[cfg(any(feature = "vllm", feature = "sglang"))]
// Allow engines to ask the main thread to wait on an extra future.
// vllm and sglang use this to shut down their sub-processes
if let Some(extra) = extra {
extra.await?;
extra.await;
}
Ok(())
......
......@@ -41,7 +41,7 @@ const DEFAULT_OUT: Output = Output::EchoFull;
const ZMQ_SOCKET_PREFIX: &str = "tio";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|trtllm|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|vllm|trtllm|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> {
logging::init();
......@@ -66,10 +66,10 @@ fn main() -> anyhow::Result<()> {
tp_rank: sglang_flags.tp_rank,
gpu_id: sglang_flags.gpu_id,
};
let node_config = sglang::MultiNodeConfig {
let node_config = triton_distributed_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
dist_init_addr: flags.dist_init_addr,
leader_addr: flags.leader_addr.unwrap_or_default(),
};
return sglang::run_subprocess(
ZMQ_SOCKET_PREFIX,
......@@ -99,7 +99,18 @@ fn main() -> anyhow::Result<()> {
#[cfg(feature = "vllm")]
{
use triton_distributed_llm::engines::vllm;
return vllm::run_subprocess(ZMQ_SOCKET_PREFIX, &model_config, &model_path);
let node_config = triton_distributed_llm::engines::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
leader_addr: flags.leader_addr.unwrap_or_default(),
};
return vllm::run_subprocess(
ZMQ_SOCKET_PREFIX,
&model_config,
&model_path,
node_config,
flags.tensor_parallel_size,
);
}
} else {
panic!("Rebuild with --features=vllm");
......
......@@ -27,3 +27,23 @@ pub mod vllm;
#[cfg(feature = "trtllm")]
pub mod trtllm;
#[derive(Debug, Clone)]
pub struct MultiNodeConfig {
/// How many nodes / hosts we are using
pub num_nodes: u32,
/// Unique consecutive integer to identify this node
pub node_rank: u32,
/// host:port of head / control node
pub leader_addr: String,
}
impl Default for MultiNodeConfig {
fn default() -> Self {
MultiNodeConfig {
num_nodes: 1,
node_rank: 0,
leader_addr: "".to_string(),
}
}
}
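A short sketch of how the two-node README setup would fill this in (the address is an example value):

```rust
fn example() {
    let leader = MultiNodeConfig {
        num_nodes: 2,
        node_rank: 0,
        leader_addr: "10.217.98.122:6539".to_string(),
    };
    // The follower differs only in its rank; both point at the leader's address.
    let follower = MultiNodeConfig { node_rank: 1, ..leader.clone() };
    assert_eq!(follower.leader_addr, leader.leader_addr);
    // Single-node runs keep the default and never use leader_addr.
    assert_eq!(MultiNodeConfig::default().num_nodes, 1);
}
```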
......@@ -34,11 +34,8 @@ pub async fn make_engine(
model_path: &Path,
// Unique string to name zmq sockets
sock_code: &str,
// Multi node settings:
// - num_nodes: How many nodes/hosts we are using
// - node_rank: Unique consecutive int starting at 0 to identify this node
// - dist_init_addr: Torch Distributed init method addr:port
node_conf: MultiNodeConfig,
// Multi node settings
node_conf: super::MultiNodeConfig,
// How many GPUs to use
tensor_parallel_size: u32,
// The base GPU ID to start allocating GPUs from
......@@ -77,23 +74,3 @@ impl Default for MultiGPUConfig {
}
}
}
#[derive(Debug, Clone)]
pub struct MultiNodeConfig {
/// How many nodes / hosts we are using
pub num_nodes: u32,
/// Unique consecutive integer to identify this node
pub node_rank: u32,
/// host:port of head / control node
pub dist_init_addr: Option<String>,
}
impl Default for MultiNodeConfig {
fn default() -> Self {
MultiNodeConfig {
num_nodes: 1,
node_rank: 0,
dist_init_addr: None,
}
}
}
......@@ -24,7 +24,7 @@ use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::runtime::CancellationToken;
use crate::engines::sglang::MultiNodeConfig;
use crate::engines::MultiNodeConfig;
pub struct SgLangEngine {
cancel_token: CancellationToken,
......
......@@ -16,6 +16,8 @@
use pyo3::{types::IntoPyDict, Python};
use std::{os::fd::RawFd, path::Path};
use crate::engines::MultiNodeConfig;
const PY_START_ENGINE: &std::ffi::CStr = cr#"
from multiprocessing.connection import Connection
import signal
......@@ -86,7 +88,7 @@ pub fn run_subprocess(
// The write half of a pipe, where sglang will signal when it's ready
notify_pipe_fd: RawFd,
// Multi node. Usually Default::default
node_config: super::MultiNodeConfig,
node_config: MultiNodeConfig,
// Multi GPU. Usually Default::default
gpu_config: super::MultiGPUConfig,
) -> anyhow::Result<()> {
......@@ -103,10 +105,7 @@ pub fn run_subprocess(
("gpu_id_str", &gpu_config.gpu_id.to_string()),
("nnodes_str", &node_config.num_nodes.to_string()),
("node_rank_str", &node_config.node_rank.to_string()),
(
"dist_init_addr",
&node_config.dist_init_addr.unwrap_or_default().to_string(),
),
("dist_init_addr", &node_config.leader_addr),
]
.into_py_dict(py)
.unwrap();
......
......@@ -40,7 +40,8 @@ use tokio::{io::AsyncReadExt as _, task::JoinHandle};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::runtime::CancellationToken;
use crate::engines::sglang::{MultiGPUConfig, MultiNodeConfig};
use crate::engines::sglang::MultiGPUConfig;
use crate::engines::MultiNodeConfig;
use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason;
......@@ -302,9 +303,10 @@ pub async fn start(
let py_imports = Arc::new(python_imports());
if tp_size < node_conf.num_nodes {
anyhow::bail!("Need at least as many GPUs as nodes. In nio set --tensor-parallel-size >= --num-nodes.");
anyhow::bail!(
"Need at least as many GPUs as nodes. Pass --tensor-parallel-size >= --num-nodes."
);
}
let tp_size_per_node = tp_size / node_conf.num_nodes;
let tp_rank_start = tp_size_per_node * node_conf.node_rank;
let tp_rank_end = tp_size_per_node * (node_conf.node_rank + 1);
......@@ -460,8 +462,11 @@ async fn start_sglang(
format!("--num-nodes={}", node_conf.num_nodes),
format!("--node-rank={}", node_conf.node_rank),
];
if let Some(dist_init_addr) = node_conf.dist_init_addr {
args.push(format!("--dist-init-addr={dist_init_addr}"));
if node_conf.num_nodes > 1 {
if node_conf.leader_addr.is_empty() {
anyhow::bail!("Missing --leader-addr for multi-node");
}
args.push(format!("--leader-addr={}", node_conf.leader_addr));
}
let self_path = std::env::current_exe()?;
let mut proc = tokio::process::Command::new(self_path)
......
......@@ -13,23 +13,30 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::future::Future;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use triton_distributed_runtime::pipeline::error as pipeline_error;
use triton_distributed_runtime::CancellationToken;
use crate::backend::ExecutionContext;
mod worker;
use crate::engines::MultiNodeConfig;
mod engine;
use engine::VllmEngine;
mod ray;
use ray::Ray;
mod subprocess;
pub use subprocess::run_subprocess;
pub async fn make_engine(
mod worker;
pub async fn make_leader_engine(
cancel_token: CancellationToken,
// Where to find the tokenizer and config.json
card_path: &Path,
......@@ -37,9 +44,103 @@ pub async fn make_engine(
model_path: &Path,
// Unique string to name zmq sockets
sock_code: &str,
) -> pipeline_error::Result<(ExecutionContext, tokio::task::JoinHandle<()>)> {
let mut engine = VllmEngine::new(cancel_token, sock_code, card_path, model_path).await?;
// Multi node settings
node_conf: MultiNodeConfig,
// How many GPUs to use
tensor_parallel_size: u32,
) -> pipeline_error::Result<(ExecutionContext, impl Future<Output = ()>)> {
let ray_obj = if node_conf.num_nodes > 1 {
let r = ray::start_leader(node_conf.leader_addr.parse()?)?;
tracing::info!("Leader waiting for {} total nodes", node_conf.num_nodes);
r.wait_for(cancel_token.clone(), node_conf.num_nodes)
.await?;
tracing::info!("All nodes registered");
Some(r)
} else {
None
};
let mut engine = VllmEngine::new(
cancel_token,
sock_code,
card_path,
model_path,
node_conf,
tensor_parallel_size,
)
.await?;
let vllm_process = engine.take_vllm_worker_handle();
let vllm_future = async move {
if let Err(err) = vllm_process.await {
tracing::error!("Failed stopping vllm process: {err:#}");
}
if let Some(r) = ray_obj {
if let Err(err) = r.stop().await {
tracing::error!("Failed stopping ray: {err:#}");
}
}
};
let engine: ExecutionContext = Arc::new(engine);
Ok((engine, vllm_process))
Ok((engine, vllm_future))
}
pub async fn start_follower(
cancel_token: CancellationToken,
node_conf: MultiNodeConfig,
) -> pipeline_error::Result<StopFuture> {
let r = ray::start_follower(node_conf.leader_addr.parse()?)?;
tracing::info!("Follower waiting for {} total nodes", node_conf.num_nodes);
r.wait_for(cancel_token, node_conf.num_nodes).await?;
tracing::info!("All nodes registered");
Ok(StopFuture {
state: Some(StopFutureState::New(r)),
})
}
pub struct StopFuture {
state: Option<StopFutureState>,
}
enum StopFutureState {
New(Ray),
Running(Pin<Box<dyn Future<Output = ()> + Send>>),
}
impl Future for StopFuture {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let state = match self.state.take() {
None => return Poll::Ready(()),
Some(state) => state,
};
match state {
StopFutureState::New(obj) => {
// Convert object to a stop future
let future = Box::pin(async move {
if let Err(err) = obj.stop().await {
tracing::error!("Failed calling 'ray stop': {err:#}");
}
});
self.state = Some(StopFutureState::Running(future));
// Recurse to poll the new future immediately
self.poll(cx)
}
StopFutureState::Running(mut future) => {
// Poll the stop future
match future.as_mut().poll(cx) {
Poll::Ready(()) => {
// Done, leave state as None
Poll::Ready(())
}
Poll::Pending => {
// Not ready yet, preserve the future
self.state = Some(StopFutureState::Running(future));
Poll::Pending
}
}
}
}
}
}
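A minimal follower-side sketch, assuming `cancel_token` and `node_conf` are already built from the CLI flags:

```rust
async fn follower_example(
    cancel_token: CancellationToken,
    node_conf: MultiNodeConfig,
) -> pipeline_error::Result<()> {
    // Joins the leader's Ray cluster and blocks until all nodes have registered.
    let stop_future = start_follower(cancel_token, node_conf).await?;
    // ... keep the process alive while Ray hosts the vllm workers ...
    // Awaiting the returned future runs `ray stop` on this node during shutdown.
    stop_future.await;
    Ok(())
}
```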
......@@ -19,6 +19,7 @@ use async_stream::stream;
use async_trait::async_trait;
use crate::engines::vllm::worker;
use crate::engines::MultiNodeConfig;
use crate::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
......@@ -36,8 +37,18 @@ impl VllmEngine {
sock_code: &str,
card_path: &Path,
model_path: &Path,
node_conf: MultiNodeConfig,
tensor_parallel_size: u32,
) -> anyhow::Result<Self> {
let w = worker::start(cancel_token.clone(), sock_code, card_path, model_path).await?;
let w = worker::start(
cancel_token.clone(),
sock_code,
card_path,
model_path,
node_conf,
tensor_parallel_size,
)
.await?;
let engine = VllmEngine {
cancel_token,
worker: w,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use regex::Regex;
use std::io::{BufRead, BufReader};
use std::net::SocketAddrV4;
use std::process::{Command, Stdio};
use std::time::Duration;
use thiserror::Error;
use tokio::io::AsyncBufReadExt;
use tokio::select;
use tokio::time;
use tracing;
use triton_distributed_runtime::CancellationToken;
/// Default is 16 seconds; we make it a bit shorter
const RAY_STOP_TIMEOUT_SECS: u32 = 10;
/// How long to wait for all the nodes to start.
/// This is either done manually or through some orchestration system, so either way it
/// can take some time.
const RAY_WAIT_SECS: u32 = 60 * 5;
#[derive(Debug, Error)]
pub enum RayError {
#[error("Failed to execute Ray command: {0}")]
CommandExecution(#[from] std::io::Error),
#[error("Ray command failed with exit code: {0}")]
CommandFailed(i32),
#[error("Failed to parse Ray status output")]
StatusParseError,
#[error("Timeout waiting for nodes to become active")]
WaitTimeout,
#[error("Operation cancelled")]
Cancelled,
}
#[derive(Debug, PartialEq)]
pub struct RayStatus {
pub active_nodes: Vec<String>,
pub pending_nodes_count: usize,
pub recent_failures_count: usize,
}
pub struct Ray {
#[allow(dead_code)]
leader_address: SocketAddrV4,
}
pub fn start_leader(leader_address: SocketAddrV4) -> Result<Ray, RayError> {
let ip = leader_address.ip().to_string();
let port = leader_address.port().to_string();
let mut cmd = Command::new("ray");
cmd.args([
"start",
"--head",
"--disable-usage-stats",
"--log-style=record",
&format!("--node-ip-address={}", ip),
&format!("--port={}", port),
]);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd.spawn()?;
// Process stdout
if let Some(stdout) = child.stdout.take() {
let reader = BufReader::new(stdout);
for line in reader.lines() {
if let Ok(line) = line {
tracing::info!("RAY: {line}");
}
}
}
// Process stderr
if let Some(stderr) = child.stderr.take() {
let reader = BufReader::new(stderr);
for line in reader.lines() {
if let Ok(line) = line {
tracing::info!("RAY: {line}");
}
}
}
let status = child.wait()?;
if !status.success() {
return Err(RayError::CommandFailed(status.code().unwrap_or(-1)));
}
Ok(Ray { leader_address })
}
pub fn start_follower(leader_address: SocketAddrV4) -> Result<Ray, RayError> {
let address = leader_address.to_string();
let mut cmd = Command::new("ray");
cmd.args(["start", &format!("--address={address}")]);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd.spawn()?;
// Process stdout
if let Some(stdout) = child.stdout.take() {
let reader = BufReader::new(stdout);
for line in reader.lines() {
if let Ok(line) = line {
tracing::info!("RAY: {line}");
}
}
}
// Process stderr
if let Some(stderr) = child.stderr.take() {
let reader = BufReader::new(stderr);
for line in reader.lines() {
if let Ok(line) = line {
tracing::info!("RAY: {line}");
}
}
}
let status = child.wait()?;
if !status.success() {
return Err(RayError::CommandFailed(status.code().unwrap_or(-1)));
}
Ok(Ray { leader_address })
}
impl Ray {
pub fn status(&self) -> Result<RayStatus, RayError> {
let output = Command::new("ray").arg("status").output()?;
if !output.status.success() {
return Err(RayError::CommandFailed(output.status.code().unwrap_or(-1)));
}
let output_str = String::from_utf8_lossy(&output.stdout);
parse_ray_status(&output_str).ok_or(RayError::StatusParseError)
}
pub async fn wait_for(
&self,
cancel_token: CancellationToken,
num_nodes: u32,
) -> Result<(), RayError> {
let timeout = time::sleep(Duration::from_secs(RAY_WAIT_SECS as u64));
select! {
_ = cancel_token.cancelled() => {
Err(RayError::Cancelled)
}
_ = timeout => {
Err(RayError::WaitTimeout)
}
result = self.wait_for_nodes(num_nodes) => {
result
}
}
}
async fn wait_for_nodes(&self, num_nodes: u32) -> Result<(), RayError> {
loop {
let status = self.status()?;
if status.active_nodes.len() as u32 == num_nodes {
return Ok(());
}
time::sleep(Duration::from_millis(100)).await;
}
}
pub async fn stop(&self) -> Result<(), RayError> {
let mut cmd = tokio::process::Command::new("ray");
cmd.args([
"stop",
&format!("--grace-period={RAY_STOP_TIMEOUT_SECS}"),
"--log-style=record",
]);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd.spawn()?;
// Process stdout
if let Some(stdout) = child.stdout.take() {
let reader = tokio::io::BufReader::new(stdout);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::info!("RAY: {line}");
}
}
// Process stderr
if let Some(stderr) = child.stderr.take() {
let reader = tokio::io::BufReader::new(stderr);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::info!("RAY: {line}");
}
}
let status = child.wait().await?;
if !status.success() {
return Err(RayError::CommandFailed(status.code().unwrap_or(-1)));
}
Ok(())
}
}
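For reference, a leader-side sketch of the intended lifecycle (two nodes, example address):

```rust
async fn leader_example(cancel_token: CancellationToken) -> Result<(), RayError> {
    let addr: SocketAddrV4 = "10.217.98.122:6539".parse().expect("expected host:port");
    let ray = start_leader(addr)?;        // runs `ray start --head --port=6539 ...`
    ray.wait_for(cancel_token, 2).await?; // block until both nodes show as Active
    // ... start the vllm engine; Ray schedules workers across the cluster ...
    ray.stop().await?;                    // runs `ray stop` with the shorter grace period
    Ok(())
}
```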
/// Parse the output of the `ray status` command into a RayStatus struct
fn parse_ray_status(output: &str) -> Option<RayStatus> {
let mut active_nodes = Vec::new();
let mut pending_nodes_count = 0;
let mut recent_failures_count = 0;
// Flags to track which section we're in
let mut in_active_section = false;
let mut in_pending_section = false;
let mut in_failures_section = false;
// Regex to match node IDs
let node_regex = Regex::new(r"(\d+)\s+(node_[a-f0-9]+)").unwrap();
for line in output.lines() {
let trimmed = line.trim();
if trimmed == "Active:" {
in_active_section = true;
in_pending_section = false;
in_failures_section = false;
continue;
} else if trimmed == "Pending:" {
in_active_section = false;
in_pending_section = true;
in_failures_section = false;
continue;
} else if trimmed == "Recent failures:" {
in_active_section = false;
in_pending_section = false;
in_failures_section = true;
continue;
} else if trimmed.starts_with("Resources") {
// We've reached the end of the node status section
break;
}
if in_active_section {
if let Some(captures) = node_regex.captures(trimmed) {
if let Some(node_id) = captures.get(2) {
active_nodes.push(node_id.as_str().to_string());
}
}
} else if in_pending_section && trimmed != "(no pending nodes)" {
// Count pending nodes
if let Some(captures) = Regex::new(r"(\d+)").unwrap().captures(trimmed) {
if let Some(count) = captures.get(1) {
if let Ok(count) = count.as_str().parse::<usize>() {
pending_nodes_count += count;
}
}
}
} else if in_failures_section && trimmed != "(no failures)" {
// Count failures
if let Some(captures) = Regex::new(r"(\d+)").unwrap().captures(trimmed) {
if let Some(count) = captures.get(1) {
if let Ok(count) = count.as_str().parse::<usize>() {
recent_failures_count += count;
}
}
}
}
}
Some(RayStatus {
active_nodes,
pending_nodes_count,
recent_failures_count,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_ray_status() {
let sample_output = r#"======== Autoscaler status: 2025-03-04 13:13:59.104771 ========
Node status
---------------------------------------------------------------
Active:
1 node_b09a7440bd0987680f97c35206b2475251907d0c928fdd0f52b1b38f
1 node_035ea3b640e13f3603d3debd97de8c569ed8c8b10e19ce00ea4fd070
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/256.0 CPU
0.0/16.0 GPU
0B/1.58TiB memory
0B/372.53GiB object_store_memory
Demands:
(no resource demands)
"#;
let expected = RayStatus {
active_nodes: vec![
"node_b09a7440bd0987680f97c35206b2475251907d0c928fdd0f52b1b38f".to_string(),
"node_035ea3b640e13f3603d3debd97de8c569ed8c8b10e19ce00ea4fd070".to_string(),
],
pending_nodes_count: 0,
recent_failures_count: 0,
};
let result = parse_ray_status(sample_output);
assert!(result.is_some());
assert_eq!(result.unwrap(), expected);
}
/// Test with pending nodes and failures
#[test]
fn test_parse_ray_status_with_failing() {
let sample_output_with_pending = r#"======== Autoscaler status: 2025-03-04 13:13:59.104771 ========
Node status
---------------------------------------------------------------
Active:
1 node_b09a7440bd0987680f97c35206b2475251907d0c928fdd0f52b1b38f
Pending:
2 node_pending_1
3 node_pending_2
Recent failures:
1 node_failure_1
4 node_failure_2
Resources
---------------------------------------------------------------
Usage:
0.0/256.0 CPU
"#;
let expected_with_pending = RayStatus {
active_nodes: vec![
"node_b09a7440bd0987680f97c35206b2475251907d0c928fdd0f52b1b38f".to_string(),
],
pending_nodes_count: 5, // 2 + 3
recent_failures_count: 5, // 1 + 4
};
let result = parse_ray_status(sample_output_with_pending);
assert!(result.is_some());
assert_eq!(result.unwrap(), expected_with_pending);
}
/// Test with empty output
#[test]
fn test_parse_ray_status_empty() {
let empty_output = "";
let result = parse_ray_status(empty_output);
assert!(result.is_some());
assert_eq!(result.unwrap().active_nodes.len(), 0);
}
}
......@@ -16,6 +16,8 @@
use pyo3::{types::IntoPyDict, Python};
use std::path::Path;
use crate::engines::MultiNodeConfig;
const PY_START_ENGINE: &std::ffi::CStr = cr#"
import multiprocessing
import signal
......@@ -24,7 +26,18 @@ from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext
engine_args = AsyncEngineArgs(model=f"{model_path}", served_model_name=None, tokenizer=f"{tokenizer_path}", task='generate', tokenizer_mode='auto', seed=0, max_model_len=8192, max_seq_len_to_capture=8192)
engine_args = AsyncEngineArgs(
model=f"{model_path}",
served_model_name=None,
tokenizer=f"{tokenizer_path}",
task='generate',
tokenizer_mode='auto',
seed=0,
max_model_len=8192,
max_seq_len_to_capture=8192,
tensor_parallel_size=int(tp_size_str),
pipeline_parallel_size=int(nnodes_str),
)
ipc_path = f"ipc:///tmp/{socket_id}";
......@@ -39,6 +52,8 @@ pub fn run_subprocess(
socket_id: &str,
model_card_path: &Path,
model_path: &Path,
node_config: MultiNodeConfig,
tp_size: u32,
) -> anyhow::Result<()> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
let card = model_card_path.display().to_string();
......@@ -48,6 +63,8 @@ pub fn run_subprocess(
("socket_id", socket_id),
("tokenizer_path", card.as_str()),
("model_path", model_path_str.as_str()),
("tp_size_str", &tp_size.to_string()),
("nnodes_str", &node_config.num_nodes.to_string()),
]
.into_py_dict(py)
.unwrap();
......
......@@ -29,6 +29,7 @@ use tokio::{io::AsyncBufReadExt, sync::mpsc::error::SendError};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::CancellationToken;
use crate::engines::MultiNodeConfig;
use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason;
......@@ -156,6 +157,8 @@ pub async fn start(
sock_code: &str,
card_path: &Path,
model_path: &Path,
_node_conf: MultiNodeConfig,
tensor_parallel_size: u32,
) -> anyhow::Result<VllmWorker> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
......@@ -168,7 +171,14 @@ pub async fn start(
heartbeat,
} = zmq_sockets(sock_code)?;
let vllm_process = start_vllm(card_path, model_path, &py_imports, data).await?;
let vllm_process = start_vllm(
card_path,
model_path,
&py_imports,
data,
tensor_parallel_size,
)
.await?;
let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process);
tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat));
......@@ -285,12 +295,13 @@ async fn start_vllm(
model_path: &Path,
python_imports: &Imports,
mut data_socket: async_zmq::Dealer<IntoIter<Vec<u8>>, Vec<u8>>,
tensor_parallel_size: u32,
) -> anyhow::Result<tokio::process::Child> {
// The in/out args are not used but we currently require them for parsing cli args
let vllm_args = [
"--internal-vllm-process",
&format!("--model-config={}", card_path.display()),
&format!("--model-path={}", model_path.display()),
&format!("--tensor-parallel-size={tensor_parallel_size}"),
];
let self_path = std::env::current_exe()?;
......