"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "15462b741263709b5660bc8dec6ebff57cb54d36"
Commit e1a95dab authored by Ryan McCormick's avatar Ryan McCormick Committed by GitHub
Browse files

fix: Support vllm_nixl (custom vllm patch) from dynamo-run (#84)

parent 86bc5442
...@@ -13,10 +13,13 @@ ...@@ -13,10 +13,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{ use std::collections::HashMap;
collections::HashMap, ops::Deref, path::Path, process::Stdio, sync::Arc, time::Duration, use std::ops::Deref;
vec::IntoIter, use std::path::Path;
}; use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use std::vec::IntoIter;
use async_zmq::{SinkExt, StreamExt}; use async_zmq::{SinkExt, StreamExt};
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
...@@ -25,11 +28,12 @@ use pyo3::{ ...@@ -25,11 +28,12 @@ use pyo3::{
prelude::*, prelude::*,
types::{IntoPyDict, PyBytes, PyString}, types::{IntoPyDict, PyBytes, PyString},
}; };
use tokio::sync::mpsc::Sender; use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc::{error::SendError, Sender};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio::{io::AsyncBufReadExt, sync::mpsc::error::SendError};
use crate::engines::MultiNodeConfig; use crate::engines::MultiNodeConfig;
use crate::kv_router::protocols::ForwardPassMetrics;
use crate::protocols::common::llm_backend::LLMEngineOutput; use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::common::preprocessor::PreprocessedRequest; use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason; use crate::protocols::common::FinishReason;
...@@ -104,6 +108,10 @@ struct Sockets { ...@@ -104,6 +108,10 @@ struct Sockets {
output: async_zmq::Pull, output: async_zmq::Pull,
// Heartbeat messages from vllm process // Heartbeat messages from vllm process
heartbeat: async_zmq::Pull, heartbeat: async_zmq::Pull,
// NOTE: Metrics socket usage is custom to our patch of vllm, and may not
// be present when running upstream vllm.
// Metrics messages from vllm process
metrics: async_zmq::Pull,
} }
/// The message vllm sends us over zmq when it's ready to work. /// The message vllm sends us over zmq when it's ready to work.
...@@ -169,6 +177,7 @@ pub async fn start( ...@@ -169,6 +177,7 @@ pub async fn start(
input, input,
output, output,
heartbeat, heartbeat,
metrics,
} = zmq_sockets(sock_code)?; } = zmq_sockets(sock_code)?;
let vllm_process = start_vllm( let vllm_process = start_vllm(
...@@ -182,6 +191,7 @@ pub async fn start( ...@@ -182,6 +191,7 @@ pub async fn start(
let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process); let vllm_join_handle = watch_vllm(cancel_token.clone(), vllm_process);
tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat)); tokio::spawn(heartbeat_loop(cancel_token.clone(), heartbeat));
tokio::spawn(metrics_loop(cancel_token.clone(), metrics));
let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new())); let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let (tx, rx) = tokio::sync::mpsc::channel(8); let (tx, rx) = tokio::sync::mpsc::channel(8);
...@@ -280,12 +290,19 @@ fn zmq_sockets(sock_code: &str) -> anyhow::Result<Sockets> { ...@@ -280,12 +290,19 @@ fn zmq_sockets(sock_code: &str) -> anyhow::Result<Sockets> {
.with_context(&zmq_context) .with_context(&zmq_context)
.connect()?; .connect()?;
let metrics = async_zmq::pull(&format!("ipc:///tmp/{sock_code}_metrics_socket"))?
.with_context(&zmq_context)
.connect()?;
// TODO: NIXL/Prefill sockets here in the future for disagg?
Ok(Sockets { Ok(Sockets {
context: zmq_context, context: zmq_context,
data, data,
input, input,
output, output,
heartbeat, heartbeat,
metrics,
}) })
} }
...@@ -448,6 +465,97 @@ async fn heartbeat_loop(cancel_token: CancellationToken, mut socket: async_zmq:: ...@@ -448,6 +465,97 @@ async fn heartbeat_loop(cancel_token: CancellationToken, mut socket: async_zmq::
} }
} }
// NOTE: Custom to our patch of vllm.
/// Reads metrics messages from the vllm metrics socket until cancelled.
///
/// Each message's first frame is unpickled into a [`ForwardPassMetrics`] and
/// logged at debug level. The loop exits when `cancel_token` is cancelled, when
/// the socket stream ends, or when reading from the socket fails. A message that
/// cannot be deserialized is logged as an error and the loop keeps running.
async fn metrics_loop(cancel_token: CancellationToken, mut socket: async_zmq::Pull) {
    loop {
        let maybe_metrics = tokio::select! {
            _ = cancel_token.cancelled() => {
                break;
            }
            maybe_metrics = socket.next() => {
                maybe_metrics
            }
        };
        let frames = match maybe_metrics {
            Some(Ok(frames)) => frames,
            Some(Err(err)) => {
                tracing::error!("Error reading from vllm metrics socket: {err}");
                break;
            }
            None => {
                tracing::debug!("vllm metrics socket closed");
                break;
            }
        };
        // Only the first frame carries the pickled payload. Use a checked lookup
        // so an unexpected empty message cannot panic this detached task and
        // silently stop metrics collection.
        let payload: &[u8] = match frames.first() {
            Some(frame) => frame.deref(),
            None => {
                tracing::warn!("Received empty message on vllm metrics socket, ignoring");
                continue;
            }
        };
        match unpickle_forward_pass_metrics(payload) {
            Ok(metrics) => {
                // TODO: These metrics could be attached to StatsHandler or Events
                // for aggregation and visualization.
                tracing::debug!("Received vllm metrics: {:?}", metrics);
            }
            Err(err) => {
                tracing::error!(
                    "Error deserializing vllm metrics with Python pickle: {}",
                    err
                );
            }
        }
    }
}

/// Deserializes one pickled metrics payload into a [`ForwardPassMetrics`].
///
/// Returns `Err` with a description when the `pickle` module cannot be imported,
/// `pickle.loads` cannot be looked up, or the `pickle.loads` call itself fails.
/// An attribute that is missing from the unpickled object, or that cannot be
/// extracted as the expected numeric type, is replaced by zero rather than
/// treated as an error.
fn unpickle_forward_pass_metrics(payload: &[u8]) -> Result<ForwardPassMetrics, String> {
    // SECURITY: `pickle.loads` can execute arbitrary code while deserializing.
    // This is only acceptable because the payload is expected to come from the
    // vllm process we talk to over the metrics socket. Do not reuse this helper
    // for data from an untrusted source.
    Python::with_gil(|py| -> Result<ForwardPassMetrics, String> {
        let pickle = py
            .import("pickle")
            .map_err(|e| format!("Failed to import pickle: {}", e))?;
        let loads = pickle
            .getattr("loads")
            .map_err(|e| format!("Failed to get loads function: {}", e))?;
        let bytes = PyBytes::new(py, payload);
        let result = loads
            .call1((bytes,))
            .map_err(|e| format!("Failed to call pickle.loads: {}", e))?;

        // Attribute readers for the integer and float fields of the Python object.
        let extract_field = |field: &str| -> Result<u64, String> {
            result
                .getattr(field)
                .map_err(|e| format!("Field '{}' not found: {}", field, e))?
                .extract::<u64>()
                .map_err(|e| format!("Failed to extract '{}' as u64: {}", field, e))
        };
        let extract_float_field = |field: &str| -> Result<f32, String> {
            result
                .getattr(field)
                .map_err(|e| format!("Field '{}' not found: {}", field, e))?
                .extract::<f32>()
                .map_err(|e| format!("Failed to extract '{}' as f32: {}", field, e))
        };

        // Give default values for any fields not found
        Ok(ForwardPassMetrics {
            request_active_slots: extract_field("request_active_slots").unwrap_or(0),
            request_total_slots: extract_field("request_total_slots").unwrap_or(0),
            kv_active_blocks: extract_field("kv_active_blocks").unwrap_or(0),
            kv_total_blocks: extract_field("kv_total_blocks").unwrap_or(0),
            num_requests_waiting: extract_field("num_requests_waiting").unwrap_or(0),
            gpu_cache_usage_perc: extract_float_field("gpu_cache_usage_perc").unwrap_or(0.0),
            gpu_prefix_cache_hit_rate: extract_float_field("gpu_prefix_cache_hit_rate")
                .unwrap_or(0.0),
        })
    })
}
fn from_vllm(output: CompletionOutput, previous_total_toks: usize) -> LLMEngineOutput { fn from_vllm(output: CompletionOutput, previous_total_toks: usize) -> LLMEngineOutput {
let finish_reason = match output.finish_reason.as_deref() { let finish_reason = match output.finish_reason.as_deref() {
Some("stop") => Some(FinishReason::Stop), Some("stop") => Some(FinishReason::Stop),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment