Unverified Commit 1f07dab7 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Add migration to LLM requests (#1930)

parent 5f179186
......@@ -162,6 +162,11 @@ pub struct Flags {
#[arg(long)]
pub request_template: Option<PathBuf>,
/// How many times a request can be migrated to another worker if the HTTP server lost
/// connection to the current worker.
#[arg(long, value_parser = clap::value_parser!(u32).range(0..1024))]
pub migration_limit: Option<u32>,
/// Everything after a `--`.
/// These are the command line arguments to the python engine when using `pystr` or `pytok`.
#[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
......@@ -180,6 +185,9 @@ impl Flags {
if self.kv_cache_block_size.is_some() {
anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
}
if self.migration_limit.is_some() {
anyhow::bail!("'--migration-limit' flag should only be used on the worker node, not on the ingress");
}
}
Output::EchoFull => {}
Output::EchoCore => {
......
......@@ -45,7 +45,8 @@ pub async fn run(
.context_length(flags.context_length)
.http_port(Some(flags.http_port))
.router_config(Some(flags.router_config()))
.request_template(flags.request_template.clone());
.request_template(flags.request_template.clone())
.migration_limit(flags.migration_limit);
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one.
......
......@@ -48,6 +48,8 @@ pub async fn start(
card.kv_cache_block_size.to_string(),
"--context-length".to_string(),
card.context_length.to_string(),
"--migration-limit".to_string(),
card.migration_limit.to_string(),
];
// TRTLLM only
// The worker node will only publish events and metrics if the router mode is KV
......
......@@ -42,6 +42,7 @@ class Config:
nnodes: int
node_rank: int
dist_init_addr: str
migration_limit: int
extra_engine_args: str
......@@ -202,7 +203,13 @@ async def init(runtime: DistributedRuntime, config: Config):
model_type = (
ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
)
await register_llm(model_type, endpoint, config.model_path, config.model_name)
await register_llm(
model_type,
endpoint,
config.model_path,
config.model_name,
migration_limit=config.migration_limit,
)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
......@@ -268,6 +275,12 @@ def cmd_line_args():
default="",
help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -304,6 +317,7 @@ def cmd_line_args():
config.nnodes = args.nnodes
config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
......
......@@ -122,6 +122,7 @@ class Config:
model_name: Optional[str] = None
tensor_parallel_size: int
kv_block_size: int
migration_limit: int
extra_engine_args: str
publish_events_and_metrics: bool
disaggregation_mode: str
......@@ -136,6 +137,7 @@ class Config:
f"model_name={self.model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, "
f"migration_limit={self.migration_limit}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
......@@ -404,6 +406,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
)
# publisher will be set later if publishing is enabled.
......@@ -476,6 +479,12 @@ def cmd_line_args():
default=None,
help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -557,6 +566,7 @@ def cmd_line_args():
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
config.disaggregation_mode = disaggregation_mode
......
......@@ -56,6 +56,7 @@ class Config:
tensor_parallel_size: int
kv_block_size: int
context_length: int
migration_limit: int
extra_engine_args: str
......@@ -233,6 +234,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"max_model_len", None
), # if None, takes length from tokenizer
kv_cache_block_size=arg_map["block_size"],
migration_limit=config.migration_limit,
)
handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics()
......@@ -276,6 +278,12 @@ def cmd_line_args():
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -308,6 +316,7 @@ def cmd_line_args():
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
......
......@@ -65,6 +65,7 @@ class Config:
tensor_parallel_size: int
kv_block_size: int
context_length: int
migration_limit: int
extra_engine_args: str
......@@ -218,6 +219,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
)
arg_map = {
......@@ -333,6 +335,12 @@ def cmd_line_args():
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -365,6 +373,7 @@ def cmd_line_args():
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
......
......@@ -131,7 +131,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
}
#[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None))]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0))]
#[allow(clippy::too_many_arguments)]
fn register_llm<'p>(
py: Python<'p>,
......@@ -142,6 +142,7 @@ fn register_llm<'p>(
context_length: Option<u32>,
kv_cache_block_size: Option<u32>,
router_mode: Option<RouterMode>,
migration_limit: u32,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat,
......@@ -162,7 +163,8 @@ fn register_llm<'p>(
.model_name(model_name)
.context_length(context_length)
.kv_cache_block_size(kv_cache_block_size)
.router_config(Some(router_config));
.router_config(Some(router_config))
.migration_limit(Some(migration_limit));
// Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
......
......@@ -19,6 +19,7 @@ use dynamo_runtime::{
use crate::{
backend::Backend,
kv_router::{KvPushRouter, KvRouterConfig},
migration::Migration,
model_type::ModelType,
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest},
protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput},
......@@ -197,12 +198,14 @@ impl ModelWatcher {
// function. Needs checking carefully, possibly we need to store it in state.
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
// Chat Completions
let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
client.clone(),
......@@ -231,19 +234,23 @@ impl ModelWatcher {
let chat_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(service_backend)?
.link(migration.backward_edge())?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?;
// Completions
let frontend = SegmentSource::<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
client,
......@@ -272,7 +279,9 @@ impl ModelWatcher {
let completions_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(service_backend)?
.link(migration.backward_edge())?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
......
......@@ -22,6 +22,7 @@ pub mod hub;
// pub mod key_value_store;
pub mod kv_router;
pub mod local_model;
pub mod migration;
pub mod mocker;
pub mod model_card;
pub mod model_type;
......
......@@ -46,6 +46,7 @@ pub struct LocalModelBuilder {
router_config: Option<RouterConfig>,
kv_cache_block_size: u32,
http_port: u16,
migration_limit: u32,
}
impl Default for LocalModelBuilder {
......@@ -60,6 +61,7 @@ impl Default for LocalModelBuilder {
context_length: Default::default(),
template_file: Default::default(),
router_config: Default::default(),
migration_limit: Default::default(),
}
}
}
......@@ -112,6 +114,11 @@ impl LocalModelBuilder {
self
}
pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
self.migration_limit = migration_limit.unwrap_or(0);
self
}
/// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path
......@@ -137,10 +144,12 @@ impl LocalModelBuilder {
// echo_full engine doesn't need a path. It's an edge case, move it out of the way.
if self.model_path.is_none() {
return Ok(LocalModel {
card: ModelDeploymentCard::with_name_only(
let mut card = ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
),
);
card.migration_limit = self.migration_limit;
return Ok(LocalModel {
card,
full_path: PathBuf::new(),
endpoint_id,
template,
......@@ -194,6 +203,8 @@ impl LocalModelBuilder {
card.context_length = context_length;
}
card.migration_limit = self.migration_limit;
Ok(LocalModel {
card,
full_path,
......
This diff is collapsed.
......@@ -92,6 +92,7 @@ impl ModelDeploymentCard {
last_published: None,
context_length,
kv_cache_block_size: 0,
migration_limit: 0,
})
}
......@@ -131,6 +132,7 @@ impl ModelDeploymentCard {
last_published: None,
context_length,
kv_cache_block_size: 0, // set later
migration_limit: 0,
})
}
}
......
......@@ -127,6 +127,10 @@ pub struct ModelDeploymentCard {
/// Size of a KV cache block - vllm only currently
/// Passed to the engine and the KV router.
pub kv_cache_block_size: u32,
/// How many times a request can be migrated to another worker if the HTTP server lost
/// connection to the current worker.
pub migration_limit: u32,
}
impl ModelDeploymentCard {
......
......@@ -136,11 +136,11 @@ impl LLMEngineOutput {
}
impl MaybeError for LLMEngineOutput {
fn from_err(err: Box<dyn std::error::Error>) -> Self {
fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
LLMEngineOutput::error(format!("{:?}", err))
}
fn err(&self) -> Option<Box<dyn std::error::Error>> {
fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
Some(anyhow::Error::msg(err_msg.clone()).into())
} else {
......
......@@ -6,14 +6,8 @@ use crate::pipeline::{
SingleIn,
};
use arc_swap::ArcSwap;
use rand::Rng;
use std::collections::HashMap;
use std::sync::RwLock;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
};
use std::time::Instant;
use std::sync::Arc;
use tokio::net::unix::pipe::Receiver;
use crate::{
......@@ -48,10 +42,8 @@ pub struct Client {
pub endpoint: Endpoint,
// These are the remotes I know about from watching etcd
pub instance_source: Arc<InstanceSource>,
// These are the instances that are reported as down from sending rpc
instance_inhibited: Arc<Mutex<HashMap<i64, Instant>>>,
// The current active IDs
instance_cache: Arc<ArcSwap<Vec<i64>>>,
// These are the instance source ids less those reported as down from sending rpc
instance_avail: Arc<ArcSwap<Vec<i64>>>,
}
#[derive(Clone, Debug)]
......@@ -60,16 +52,13 @@ pub enum InstanceSource {
Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
}
// TODO: Avoid returning a full clone of `Vec<Instance>` everytime from Client
// See instances() and instances_avail() methods
impl Client {
// Client will only talk to a single static endpoint
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client {
endpoint,
instance_source: Arc::new(InstanceSource::Static),
instance_inhibited: Arc::new(Mutex::new(HashMap::new())),
instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
})
}
......@@ -85,26 +74,12 @@ impl Client {
let instance_source =
Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;
let cancel_token = endpoint.drt().primary_token();
let client = Client {
endpoint,
instance_source,
instance_inhibited: Arc::new(Mutex::new(HashMap::new())),
instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
};
let instance_source_c = client.instance_source.clone();
let instance_inhibited_c = Arc::clone(&client.instance_inhibited);
let instance_cache_c = Arc::clone(&client.instance_cache);
tokio::task::spawn(async move {
while !cancel_token.is_cancelled() {
refresh_instances(&instance_source_c, &instance_inhibited_c, &instance_cache_c);
tokio::select! {
_ = cancel_token.cancelled() => {}
_ = tokio::time::sleep(INSTANCE_REFRESH_PERIOD) => {}
}
}
});
client.monitor_instance_source();
Ok(client)
}
......@@ -119,13 +94,20 @@ impl Client {
/// Instances available from watching etcd
pub fn instances(&self) -> Vec<Instance> {
instances_inner(self.instance_source.as_ref())
match self.instance_source.as_ref() {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
}
pub fn instance_ids(&self) -> Vec<i64> {
self.instances().into_iter().map(|ep| ep.id()).collect()
}
pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
self.instance_avail.load()
}
/// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
let mut instances: Vec<Instance> = vec![];
......@@ -143,24 +125,51 @@ impl Client {
Ok(instances)
}
/// Instances available from watching etcd minus those reported as down
pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
self.instance_cache.load()
/// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.instance_source.as_ref(), InstanceSource::Static)
}
/// Mark an instance as down/unavailable
pub fn report_instance_down(&self, instance_id: i64) {
self.instance_inhibited
.lock()
.unwrap()
.insert(instance_id, Instant::now());
let filtered = self
.instance_ids_avail()
.iter()
.filter_map(|&id| if id == instance_id { None } else { Some(id) })
.collect::<Vec<_>>();
self.instance_avail.store(Arc::new(filtered));
tracing::debug!("inhibiting instance {instance_id}");
}
/// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.instance_source.as_ref(), InstanceSource::Static)
/// Monitor the ETCD instance source and update instance_avail.
fn monitor_instance_source(&self) {
let cancel_token = self.endpoint.drt().primary_token();
let client = self.clone();
tokio::task::spawn(async move {
let mut rx = match client.instance_source.as_ref() {
InstanceSource::Static => {
tracing::error!("Static instance source is not watchable");
return;
}
InstanceSource::Dynamic(rx) => rx.clone(),
};
while !cancel_token.is_cancelled() {
let instance_ids: Vec<i64> = rx
.borrow_and_update()
.iter()
.map(|instance| instance.id())
.collect();
client.instance_avail.store(Arc::new(instance_ids));
tracing::debug!("instance source updated");
if let Err(err) = rx.changed().await {
tracing::error!("The Sender is dropped: {}", err);
cancel_token.cancel();
}
}
});
}
async fn get_or_create_dynamic_instance_source(
......@@ -253,49 +262,3 @@ impl Client {
Ok(instance_source)
}
}
/// Update the instance id cache
fn refresh_instances(
instance_source: &InstanceSource,
instance_inhibited: &Arc<Mutex<HashMap<i64, Instant>>>,
instance_cache: &Arc<ArcSwap<Vec<i64>>>,
) {
const ETCD_LEASE_TTL: u64 = 10; // seconds
// TODO: Can we get the remaining TTL from the lease for the instance?
let now = Instant::now();
let instances = instances_inner(instance_source);
let mut inhibited = instance_inhibited.lock().unwrap();
// 1. Remove inhibited instances that are no longer in `self.instances()`
// 2. Remove inhibited instances that have expired
// 3. Only return instances that are not inhibited after removals
let mut new_inhibited = HashMap::<i64, Instant>::new();
let filtered: Vec<i64> = instances
.into_iter()
.filter_map(|instance| {
let id = instance.id();
if let Some(&timestamp) = inhibited.get(&id) {
if now.duration_since(timestamp).as_secs() > ETCD_LEASE_TTL {
Some(id)
} else {
new_inhibited.insert(id, timestamp);
None
}
} else {
Some(id)
}
})
.collect();
*inhibited = new_inhibited;
instance_cache.store(Arc::new(filtered));
}
fn instances_inner(instance_source: &InstanceSource) -> Vec<Instance> {
match instance_source {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
}
......@@ -178,20 +178,14 @@ where
Ok(stream) => {
let engine_ctx = stream.context();
let client = self.client.clone();
let stream = stream.then(move |res| {
let mut report_instance_down: Option<(Client, i64)> = None;
let stream = stream.map(move |res| {
if let Some(err) = res.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
report_instance_down = Some((client.clone(), instance_id));
}
}
async move {
if let Some((client, instance_id)) = report_instance_down {
client.report_instance_down(instance_id);
}
res
}
res
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
......
......@@ -151,11 +151,11 @@ impl<R> MaybeError for Annotated<R>
where
R: for<'de> Deserialize<'de> + Serialize,
{
fn from_err(err: Box<dyn std::error::Error>) -> Self {
fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
Annotated::from_error(format!("{:?}", err))
}
fn err(&self) -> Option<Box<dyn std::error::Error>> {
fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
if self.is_error() {
if let Some(comment) = &self.comment {
if !comment.is_empty() {
......
......@@ -17,10 +17,10 @@ use std::error::Error;
pub trait MaybeError {
/// Construct an instance from an error.
fn from_err(err: Box<dyn Error>) -> Self;
fn from_err(err: Box<dyn Error + Send + Sync>) -> Self;
/// Construct into an error instance.
fn err(&self) -> Option<Box<dyn Error>>;
fn err(&self) -> Option<Box<dyn Error + Send + Sync>>;
/// Check if the current instance represents a success.
fn is_ok(&self) -> bool {
......@@ -41,12 +41,12 @@ mod tests {
message: String,
}
impl MaybeError for TestError {
fn from_err(err: Box<dyn Error>) -> Self {
fn from_err(err: Box<dyn Error + Send + Sync>) -> Self {
TestError {
message: err.to_string(),
}
}
fn err(&self) -> Option<Box<dyn Error>> {
fn err(&self) -> Option<Box<dyn Error + Send + Sync>> {
Some(anyhow::Error::msg(self.message.clone()).into())
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment