Unverified Commit 1f07dab7 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Add migration to LLM requests (#1930)

parent 5f179186
......@@ -162,6 +162,11 @@ pub struct Flags {
#[arg(long)]
pub request_template: Option<PathBuf>,
/// How many times a request can be migrated to another worker if the HTTP server lost
/// connection to the current worker.
#[arg(long, value_parser = clap::value_parser!(u32).range(0..1024))]
pub migration_limit: Option<u32>,
/// Everything after a `--`.
/// These are the command line arguments to the python engine when using `pystr` or `pytok`.
#[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
......@@ -180,6 +185,9 @@ impl Flags {
if self.kv_cache_block_size.is_some() {
anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
}
if self.migration_limit.is_some() {
anyhow::bail!("'--migration-limit' flag should only be used on the worker node, not on the ingress");
}
}
Output::EchoFull => {}
Output::EchoCore => {
......
......@@ -45,7 +45,8 @@ pub async fn run(
.context_length(flags.context_length)
.http_port(Some(flags.http_port))
.router_config(Some(flags.router_config()))
.request_template(flags.request_template.clone());
.request_template(flags.request_template.clone())
.migration_limit(flags.migration_limit);
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one.
......
......@@ -48,6 +48,8 @@ pub async fn start(
card.kv_cache_block_size.to_string(),
"--context-length".to_string(),
card.context_length.to_string(),
"--migration-limit".to_string(),
card.migration_limit.to_string(),
];
// TRTLLM only
// The worker node will only publish events and metrics if the router mode is KV
......
......@@ -42,6 +42,7 @@ class Config:
nnodes: int
node_rank: int
dist_init_addr: str
migration_limit: int
extra_engine_args: str
......@@ -202,7 +203,13 @@ async def init(runtime: DistributedRuntime, config: Config):
model_type = (
ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
)
await register_llm(model_type, endpoint, config.model_path, config.model_name)
await register_llm(
model_type,
endpoint,
config.model_path,
config.model_name,
migration_limit=config.migration_limit,
)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
......@@ -268,6 +275,12 @@ def cmd_line_args():
default="",
help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -304,6 +317,7 @@ def cmd_line_args():
config.nnodes = args.nnodes
config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
......
......@@ -122,6 +122,7 @@ class Config:
model_name: Optional[str] = None
tensor_parallel_size: int
kv_block_size: int
migration_limit: int
extra_engine_args: str
publish_events_and_metrics: bool
disaggregation_mode: str
......@@ -136,6 +137,7 @@ class Config:
f"model_name={self.model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, "
f"migration_limit={self.migration_limit}, "
f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, "
......@@ -404,6 +406,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
)
# publisher will be set later if publishing is enabled.
......@@ -476,6 +479,12 @@ def cmd_line_args():
default=None,
help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -557,6 +566,7 @@ def cmd_line_args():
config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics
config.disaggregation_mode = disaggregation_mode
......
......@@ -56,6 +56,7 @@ class Config:
tensor_parallel_size: int
kv_block_size: int
context_length: int
migration_limit: int
extra_engine_args: str
......@@ -233,6 +234,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"max_model_len", None
), # if None, takes length from tokenizer
kv_cache_block_size=arg_map["block_size"],
migration_limit=config.migration_limit,
)
handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics()
......@@ -276,6 +278,12 @@ def cmd_line_args():
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -308,6 +316,7 @@ def cmd_line_args():
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
......
......@@ -65,6 +65,7 @@ class Config:
tensor_parallel_size: int
kv_block_size: int
context_length: int
migration_limit: int
extra_engine_args: str
......@@ -218,6 +219,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.model_path,
config.model_name,
kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
)
arg_map = {
......@@ -333,6 +335,12 @@ def cmd_line_args():
default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
)
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument(
"--extra-engine-args",
type=str,
......@@ -365,6 +373,7 @@ def cmd_line_args():
config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size
config.context_length = args.context_length
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args
return config
......
......@@ -131,7 +131,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
}
#[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None))]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0))]
#[allow(clippy::too_many_arguments)]
fn register_llm<'p>(
py: Python<'p>,
......@@ -142,6 +142,7 @@ fn register_llm<'p>(
context_length: Option<u32>,
kv_cache_block_size: Option<u32>,
router_mode: Option<RouterMode>,
migration_limit: u32,
) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat,
......@@ -162,7 +163,8 @@ fn register_llm<'p>(
.model_name(model_name)
.context_length(context_length)
.kv_cache_block_size(kv_cache_block_size)
.router_config(Some(router_config));
.router_config(Some(router_config))
.migration_limit(Some(migration_limit));
// Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
......
......@@ -19,6 +19,7 @@ use dynamo_runtime::{
use crate::{
backend::Backend,
kv_router::{KvPushRouter, KvRouterConfig},
migration::Migration,
model_type::ModelType,
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest},
protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput},
......@@ -197,12 +198,14 @@ impl ModelWatcher {
// function. Needs checking carefully, possibly we need to store it in state.
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
// Chat Completions
let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
client.clone(),
......@@ -231,19 +234,23 @@ impl ModelWatcher {
let chat_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(service_backend)?
.link(migration.backward_edge())?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?;
// Completions
let frontend = SegmentSource::<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
client,
......@@ -272,7 +279,9 @@ impl ModelWatcher {
let completions_engine = frontend
.link(preprocessor.forward_edge())?
.link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(service_backend)?
.link(migration.backward_edge())?
.link(backend.backward_edge())?
.link(preprocessor.backward_edge())?
.link(frontend)?;
......
......@@ -22,6 +22,7 @@ pub mod hub;
// pub mod key_value_store;
pub mod kv_router;
pub mod local_model;
pub mod migration;
pub mod mocker;
pub mod model_card;
pub mod model_type;
......
......@@ -46,6 +46,7 @@ pub struct LocalModelBuilder {
router_config: Option<RouterConfig>,
kv_cache_block_size: u32,
http_port: u16,
migration_limit: u32,
}
impl Default for LocalModelBuilder {
......@@ -60,6 +61,7 @@ impl Default for LocalModelBuilder {
context_length: Default::default(),
template_file: Default::default(),
router_config: Default::default(),
migration_limit: Default::default(),
}
}
}
......@@ -112,6 +114,11 @@ impl LocalModelBuilder {
self
}
pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
self.migration_limit = migration_limit.unwrap_or(0);
self
}
/// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path
......@@ -137,10 +144,12 @@ impl LocalModelBuilder {
// echo_full engine doesn't need a path. It's an edge case, move it out of the way.
if self.model_path.is_none() {
let mut card = ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
);
card.migration_limit = self.migration_limit;
return Ok(LocalModel {
card: ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
),
card,
full_path: PathBuf::new(),
endpoint_id,
template,
......@@ -194,6 +203,8 @@ impl LocalModelBuilder {
card.context_length = context_length;
}
card.migration_limit = self.migration_limit;
Ok(LocalModel {
card,
full_path,
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use anyhow::{Error, Result};
use futures::{stream, stream::StreamExt};
use async_nats::client::{
RequestError as NatsRequestError, RequestErrorKind::NoResponders as NatsNoResponders,
};
use crate::{
model_card::model::ModelDeploymentCard,
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
};
use dynamo_runtime::{
pipeline::{
async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
};
pub struct Migration {
migration_limit: u32,
}
impl Migration {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
migration_limit: mdc.migration_limit,
}))
}
}
#[async_trait]
impl
Operator<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
> for Migration
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let (preprocessed_request, context) = request.transfer(());
let engine_ctx = context.context();
let retry_manager =
RetryManager::build(preprocessed_request, next, self.migration_limit).await?;
let response_stream = stream::unfold(retry_manager, |mut retry_manager| async move {
retry_manager
.next()
.await
.map(|response| (response, retry_manager))
});
Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx))
}
}
struct RetryManager {
request: PreprocessedRequest,
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
retries_left: u32,
}
impl RetryManager {
pub async fn build(
preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
retries_left: u32,
) -> Result<Self> {
let mut slf = Self {
request: preprocessed_request,
next_generate: next,
next_stream: None,
retries_left: retries_left + 1, // +1 to account for the initial attempt
};
slf.new_stream().await?;
Ok(slf)
}
pub async fn next(&mut self) -> Option<Annotated<LLMEngineOutput>> {
loop {
let response_stream = match self.next_stream.as_mut() {
Some(stream) => stream,
None => {
tracing::error!("next() called with next_stream is None - should not happen");
return Some(Annotated::from_err(
Error::msg("next_stream is None").into(),
));
}
};
if let Some(response) = response_stream.next().await {
if let Some(err) = response.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
tracing::info!("Stream disconnected... recreating stream...");
if let Err(err) = self.new_stream().await {
tracing::info!("Cannot recreate stream: {:?}", err);
} else {
continue;
}
}
}
self.track_response(&response);
return Some(response);
}
return None;
}
}
async fn new_stream(&mut self) -> Result<()> {
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
while self.retries_left > 0 {
self.retries_left -= 1;
// TODO: Is there anything needed to pass between context?
let request = SingleIn::new(self.request.clone());
response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err() {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() {
if matches!(req_err.kind(), NatsNoResponders) {
tracing::info!("Creating new stream... retrying...");
continue;
}
}
}
break;
}
match response_stream {
Some(Ok(next_stream)) => {
self.next_stream = Some(next_stream);
Ok(())
}
Some(Err(err)) => Err(err), // should propagate streaming error if stream started
None => Err(Error::msg(
"Retries exhausted - should propagate streaming error",
)),
}
}
fn track_response(&mut self, response: &Annotated<LLMEngineOutput>) {
if self.retries_left == 0 {
return;
}
let llm_engine_output = match response.data.as_ref() {
Some(output) => output,
None => return,
};
for token_id in llm_engine_output.token_ids.iter() {
self.request.token_ids.push(*token_id);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::common::{SamplingOptions, StopConditions};
use dynamo_runtime::pipeline::context::Controller;
use dynamo_runtime::pipeline::AsyncEngine;
use std::sync::atomic::{AtomicU32, Ordering};
use tokio::sync::mpsc;
// Helper to create a mock preprocessed request
fn create_mock_request() -> PreprocessedRequest {
PreprocessedRequest {
token_ids: vec![1, 2, 3],
batch_token_ids: None,
stop_conditions: StopConditions::default(),
sampling_options: SamplingOptions::default(),
eos_token_ids: vec![],
mdc_sum: None,
annotations: vec![],
estimated_prefix_hit_num_blocks: None,
}
}
// Helper to create mock LLM engine output
fn create_mock_output(token_id: u32) -> Annotated<LLMEngineOutput> {
Annotated::from_data(LLMEngineOutput {
token_ids: vec![token_id],
tokens: None,
text: Some(format!("token_{}", token_id)),
cum_log_probs: None,
log_probs: None,
finish_reason: None,
index: None,
})
}
#[derive(Debug, Clone)]
enum MockBehavior {
/// Always succeeds with all responses
Success,
/// Fails on first call with NoResponders error, then succeeds on subsequent calls
FailThenSuccess,
/// Succeeds initially, fails mid-stream with specific error, then succeeds on retry
MidStreamFail { fail_after: usize },
/// Succeeds initially, fails mid-stream with specific error, then always fails on retry attempts
MidStreamFailAlways { fail_after: usize },
/// Succeeds initially, fails mid-stream, then always fails with stream error on retry attempts
MidStreamFailAlwaysStreamError { fail_after: usize },
/// Always fails with NoResponders error (same as FailThenSuccess first call)
AlwaysFail,
}
// Unified mock server streaming engine that can simulate different scenarios
struct MockEngine {
behavior: MockBehavior,
num_responses: usize,
token_offset: u32,
call_count: Arc<AtomicU32>,
}
impl MockEngine {
fn new(behavior: MockBehavior, num_responses: usize, token_offset: u32) -> Self {
Self {
behavior,
num_responses,
token_offset,
call_count: Arc::new(AtomicU32::new(0)),
}
}
}
#[async_trait]
impl
AsyncEngine<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
anyhow::Error,
> for MockEngine
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
let (preprocessed_request, _) = request.transfer(());
// Calculate how many responses we've already generated based on request token_ids
// Initial request has [1, 2, 3], so anything beyond that are generated responses
let initial_tokens = 3; // [1, 2, 3]
let responses_already_generated = preprocessed_request
.token_ids
.len()
.saturating_sub(initial_tokens);
let _responses_remaining = self
.num_responses
.saturating_sub(responses_already_generated);
match &self.behavior {
MockBehavior::Success => {
// Always succeed with remaining responses
self.send_responses(responses_already_generated, self.num_responses)
.await
}
MockBehavior::FailThenSuccess => {
if call_num == 0 {
// First call - return "No responders available" error to trigger retry
let nats_error: NatsRequestError = NatsNoResponders.into();
return Err(nats_error.into());
} else {
// Subsequent calls - succeed with remaining responses
self.send_responses(responses_already_generated, self.num_responses)
.await
}
}
MockBehavior::MidStreamFail { fail_after } => {
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
let fail_after = *fail_after;
let num_responses = self.num_responses;
if call_num == 0 {
// First call - send some responses then an error to simulate disconnection
tokio::spawn(async move {
// Send responses from current position to fail_after
for i in responses_already_generated..fail_after.min(num_responses) {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
// Send the specific error that triggers retry logic
let error_response = Annotated::from_err(
anyhow::Error::msg("Stream ended before generation completed")
.into(),
);
let _ = tx.send(error_response).await;
});
} else {
// Second call - send remaining responses from where we left off
tokio::spawn(async move {
for i in responses_already_generated..num_responses {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
});
}
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
}
MockBehavior::MidStreamFailAlways { fail_after } => {
if call_num == 0 {
// First call - send some responses then an error to simulate disconnection
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
let fail_after = *fail_after;
let num_responses = self.num_responses;
tokio::spawn(async move {
// Send responses from current position to fail_after
for i in responses_already_generated..fail_after.min(num_responses) {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
// Send the specific error that triggers retry logic
let error_response = Annotated::from_err(
anyhow::Error::msg("Stream ended before generation completed")
.into(),
);
let _ = tx.send(error_response).await;
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
} else {
// Subsequent calls - always fail with NoResponders error (same as AlwaysFail)
let nats_error: NatsRequestError = NatsNoResponders.into();
Err(nats_error.into())
}
}
MockBehavior::MidStreamFailAlwaysStreamError { fail_after } => {
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
let fail_after = *fail_after;
let num_responses = self.num_responses;
if call_num == 0 {
// First call - send some responses then an error to simulate disconnection
tokio::spawn(async move {
// Send responses from current position to fail_after
for i in responses_already_generated..fail_after.min(num_responses) {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
// Send the specific error that triggers retry logic
let error_response = Annotated::from_err(
anyhow::Error::msg("Stream ended before generation completed")
.into(),
);
let _ = tx.send(error_response).await;
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
} else {
// Subsequent calls - immediately send stream error (no successful responses)
tokio::spawn(async move {
// Send the stream error immediately
let error_response = Annotated::from_err(
anyhow::Error::msg("Stream ended before generation completed")
.into(),
);
let _ = tx.send(error_response).await;
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
}
}
MockBehavior::AlwaysFail => {
// Always fail with NoResponders error (same as FailThenSuccess first call)
let nats_error: NatsRequestError = NatsNoResponders.into();
Err(nats_error.into())
}
}
}
}
impl MockEngine {
async fn send_responses(
&self,
start: usize,
end: usize,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let (tx, rx) = mpsc::channel(1);
let token_offset = self.token_offset;
tokio::spawn(async move {
for i in start..end {
let response = create_mock_output(token_offset + 1 + i as u32);
if tx.send(response).await.is_err() {
break;
}
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
))
}
}
/// Test case 1: No migration needed
/// Tests the normal case where the RetryManager successfully processes all responses
/// from a single stream without any failures or need for retries/migration.
/// Expected behavior: All 10 responses should be received successfully.
#[tokio::test]
async fn test_retry_manager_no_migration() {
let request = create_mock_request();
let mock_engine = Arc::new(MockEngine::new(MockBehavior::Success, 10, 100));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 0)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
assert_eq!(responses.len(), 10);
for (i, response) in responses.iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
}
}
}
/// Test case 2: New request migration
/// Tests the scenario where a worker becomes unreachable for new requests initially,
/// triggering the RetryManager to retry the request. The MockEngine with FailThenSuccess
/// fails on the first call with a "No responders available" error, then succeeds
/// on subsequent calls, simulating a worker becoming available after initial failure.
/// Expected behavior: All 10 responses should be received successfully after retry.
#[tokio::test]
async fn test_retry_manager_new_request_migration() {
let request = create_mock_request();
let mock_engine = Arc::new(MockEngine::new(MockBehavior::FailThenSuccess, 10, 100));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 3)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
assert_eq!(responses.len(), 10);
for (i, response) in responses.iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
}
}
}
/// Test case 3: Ongoing request migration
/// Tests the scenario where a worker fails mid-stream during an ongoing request.
/// This simulates a connection being lost after partial response delivery, requiring
/// the RetryManager to detect the failure (via "Stream ended before generation completed" error),
/// create a new stream, and continue from where it left off.
/// Expected behavior: 5 responses from first stream + 5 responses from retry stream = 10 total.
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration() {
let request = create_mock_request();
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFail { fail_after: 5 },
10,
100,
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 3)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
// Should have received all 10 responses (5 from first stream + 5 from second stream)
assert_eq!(responses.len(), 10);
// Check that we received responses from both streams
for (i, response) in responses.iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103, ..., 110
}
}
}
/// Test case 4: New request migration - indefinite failure
/// Tests the scenario where a worker becomes unreachable for new requests indefinitely.
/// The RetryManager should exhaust all retries and return the original error from the first attempt.
/// Expected behavior: Should receive an error after all retries are exhausted, with the original error.
#[tokio::test]
async fn test_retry_manager_new_request_migration_indefinite_failure() {
let request = create_mock_request();
let mock_engine = Arc::new(MockEngine::new(MockBehavior::AlwaysFail, 0, 100));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
// Should fail to build due to initial stream creation failure after exhausting all 3 retries
let retry_manager_result = RetryManager::build(request, next_generate, 3).await;
assert!(retry_manager_result.is_err());
if let Err(error) = retry_manager_result {
assert!(error.to_string().contains("no responders"));
}
}
/// Test case 5: Ongoing request migration - indefinite failure
/// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests.
/// The RetryManager should exhaust all retries and return the original stream disconnection error.
/// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
let request = create_mock_request();
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlways { fail_after: 3 },
10,
100,
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
// Collect all responses (both successful and error responses)
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
// Should have received 4 total responses: 3 successful + 1 error
assert_eq!(responses.len(), 4);
// First 3 responses should be successful with tokens 101, 102, 103
for (i, response) in responses[0..3].iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
}
}
// 4th response should be an error after retries are exhausted
let error_response = &responses[3];
assert!(error_response.err().is_some());
if let Some(error) = error_response.err() {
assert!(error
.to_string()
.contains("Stream ended before generation completed"));
}
}
/// Test case 6: Ongoing request migration - indefinite failure with stream errors
/// Tests the scenario where a worker fails mid-stream indefinitely during ongoing requests,
/// and all retry attempts also fail with stream errors instead of NATS errors.
/// Expected behavior: Should receive some responses from first stream, then error after retries exhausted.
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
let request = create_mock_request();
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
10,
100,
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
// Collect all responses (both successful and error responses)
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
// Should have received 4 total responses: 3 successful + 1 error
assert_eq!(responses.len(), 4);
// First 3 responses should be successful with tokens 101, 102, 103
for (i, response) in responses[0..3].iter().enumerate() {
assert!(response.err().is_none());
if let Some(output) = &response.data {
assert_eq!(output.token_ids, vec![101 + i as u32]); // 101, 102, 103
}
}
// 4th response should be an error after retries are exhausted
let error_response = &responses[3];
assert!(error_response.err().is_some());
if let Some(error) = error_response.err() {
assert!(error
.to_string()
.contains("Stream ended before generation completed"));
}
}
}
......@@ -92,6 +92,7 @@ impl ModelDeploymentCard {
last_published: None,
context_length,
kv_cache_block_size: 0,
migration_limit: 0,
})
}
......@@ -131,6 +132,7 @@ impl ModelDeploymentCard {
last_published: None,
context_length,
kv_cache_block_size: 0, // set later
migration_limit: 0,
})
}
}
......
......@@ -127,6 +127,10 @@ pub struct ModelDeploymentCard {
/// Size of a KV cache block - vllm only currently
/// Passed to the engine and the KV router.
pub kv_cache_block_size: u32,
/// How many times a request can be migrated to another worker if the HTTP server lost
/// connection to the current worker.
pub migration_limit: u32,
}
impl ModelDeploymentCard {
......
......@@ -136,11 +136,11 @@ impl LLMEngineOutput {
}
impl MaybeError for LLMEngineOutput {
fn from_err(err: Box<dyn std::error::Error>) -> Self {
fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
LLMEngineOutput::error(format!("{:?}", err))
}
fn err(&self) -> Option<Box<dyn std::error::Error>> {
fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
Some(anyhow::Error::msg(err_msg.clone()).into())
} else {
......
......@@ -6,14 +6,8 @@ use crate::pipeline::{
SingleIn,
};
use arc_swap::ArcSwap;
use rand::Rng;
use std::collections::HashMap;
use std::sync::RwLock;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
};
use std::time::Instant;
use std::sync::Arc;
use tokio::net::unix::pipe::Receiver;
use crate::{
......@@ -48,10 +42,8 @@ pub struct Client {
pub endpoint: Endpoint,
// These are the remotes I know about from watching etcd
pub instance_source: Arc<InstanceSource>,
// These are the instances that are reported as down from sending rpc
instance_inhibited: Arc<Mutex<HashMap<i64, Instant>>>,
// The current active IDs
instance_cache: Arc<ArcSwap<Vec<i64>>>,
// These are the instance source ids less those reported as down from sending rpc
instance_avail: Arc<ArcSwap<Vec<i64>>>,
}
#[derive(Clone, Debug)]
......@@ -60,16 +52,13 @@ pub enum InstanceSource {
Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
}
// TODO: Avoid returning a full clone of `Vec<Instance>` everytime from Client
// See instances() and instances_avail() methods
impl Client {
// Client will only talk to a single static endpoint
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client {
endpoint,
instance_source: Arc::new(InstanceSource::Static),
instance_inhibited: Arc::new(Mutex::new(HashMap::new())),
instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
})
}
......@@ -85,26 +74,12 @@ impl Client {
let instance_source =
Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;
let cancel_token = endpoint.drt().primary_token();
let client = Client {
endpoint,
instance_source,
instance_inhibited: Arc::new(Mutex::new(HashMap::new())),
instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
};
let instance_source_c = client.instance_source.clone();
let instance_inhibited_c = Arc::clone(&client.instance_inhibited);
let instance_cache_c = Arc::clone(&client.instance_cache);
tokio::task::spawn(async move {
while !cancel_token.is_cancelled() {
refresh_instances(&instance_source_c, &instance_inhibited_c, &instance_cache_c);
tokio::select! {
_ = cancel_token.cancelled() => {}
_ = tokio::time::sleep(INSTANCE_REFRESH_PERIOD) => {}
}
}
});
client.monitor_instance_source();
Ok(client)
}
......@@ -119,13 +94,20 @@ impl Client {
/// Instances available from watching etcd
pub fn instances(&self) -> Vec<Instance> {
instances_inner(self.instance_source.as_ref())
match self.instance_source.as_ref() {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
}
pub fn instance_ids(&self) -> Vec<i64> {
self.instances().into_iter().map(|ep| ep.id()).collect()
}
pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
self.instance_avail.load()
}
/// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
let mut instances: Vec<Instance> = vec![];
......@@ -143,24 +125,51 @@ impl Client {
Ok(instances)
}
/// Instances available from watching etcd minus those reported as down
pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
self.instance_cache.load()
/// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.instance_source.as_ref(), InstanceSource::Static)
}
/// Mark an instance as down/unavailable
pub fn report_instance_down(&self, instance_id: i64) {
self.instance_inhibited
.lock()
.unwrap()
.insert(instance_id, Instant::now());
let filtered = self
.instance_ids_avail()
.iter()
.filter_map(|&id| if id == instance_id { None } else { Some(id) })
.collect::<Vec<_>>();
self.instance_avail.store(Arc::new(filtered));
tracing::debug!("inhibiting instance {instance_id}");
}
/// Is this component know at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.instance_source.as_ref(), InstanceSource::Static)
/// Monitor the ETCD instance source and update instance_avail.
fn monitor_instance_source(&self) {
let cancel_token = self.endpoint.drt().primary_token();
let client = self.clone();
tokio::task::spawn(async move {
let mut rx = match client.instance_source.as_ref() {
InstanceSource::Static => {
tracing::error!("Static instance source is not watchable");
return;
}
InstanceSource::Dynamic(rx) => rx.clone(),
};
while !cancel_token.is_cancelled() {
let instance_ids: Vec<i64> = rx
.borrow_and_update()
.iter()
.map(|instance| instance.id())
.collect();
client.instance_avail.store(Arc::new(instance_ids));
tracing::debug!("instance source updated");
if let Err(err) = rx.changed().await {
tracing::error!("The Sender is dropped: {}", err);
cancel_token.cancel();
}
}
});
}
async fn get_or_create_dynamic_instance_source(
......@@ -253,49 +262,3 @@ impl Client {
Ok(instance_source)
}
}
/// Update the instance id cache
fn refresh_instances(
instance_source: &InstanceSource,
instance_inhibited: &Arc<Mutex<HashMap<i64, Instant>>>,
instance_cache: &Arc<ArcSwap<Vec<i64>>>,
) {
const ETCD_LEASE_TTL: u64 = 10; // seconds
// TODO: Can we get the remaining TTL from the lease for the instance?
let now = Instant::now();
let instances = instances_inner(instance_source);
let mut inhibited = instance_inhibited.lock().unwrap();
// 1. Remove inhibited instances that are no longer in `self.instances()`
// 2. Remove inhibited instances that have expired
// 3. Only return instances that are not inhibited after removals
let mut new_inhibited = HashMap::<i64, Instant>::new();
let filtered: Vec<i64> = instances
.into_iter()
.filter_map(|instance| {
let id = instance.id();
if let Some(&timestamp) = inhibited.get(&id) {
if now.duration_since(timestamp).as_secs() > ETCD_LEASE_TTL {
Some(id)
} else {
new_inhibited.insert(id, timestamp);
None
}
} else {
Some(id)
}
})
.collect();
*inhibited = new_inhibited;
instance_cache.store(Arc::new(filtered));
}
fn instances_inner(instance_source: &InstanceSource) -> Vec<Instance> {
match instance_source {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
}
......@@ -178,20 +178,14 @@ where
Ok(stream) => {
let engine_ctx = stream.context();
let client = self.client.clone();
let stream = stream.then(move |res| {
let mut report_instance_down: Option<(Client, i64)> = None;
let stream = stream.map(move |res| {
if let Some(err) = res.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
report_instance_down = Some((client.clone(), instance_id));
}
}
async move {
if let Some((client, instance_id)) = report_instance_down {
client.report_instance_down(instance_id);
}
res
}
res
});
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
......
......@@ -151,11 +151,11 @@ impl<R> MaybeError for Annotated<R>
where
R: for<'de> Deserialize<'de> + Serialize,
{
fn from_err(err: Box<dyn std::error::Error>) -> Self {
fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
Annotated::from_error(format!("{:?}", err))
}
fn err(&self) -> Option<Box<dyn std::error::Error>> {
fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
if self.is_error() {
if let Some(comment) = &self.comment {
if !comment.is_empty() {
......
......@@ -17,10 +17,10 @@ use std::error::Error;
pub trait MaybeError {
/// Construct an instance from an error.
fn from_err(err: Box<dyn Error>) -> Self;
fn from_err(err: Box<dyn Error + Send + Sync>) -> Self;
/// Construct into an error instance.
fn err(&self) -> Option<Box<dyn Error>>;
fn err(&self) -> Option<Box<dyn Error + Send + Sync>>;
/// Check if the current instance represents a success.
fn is_ok(&self) -> bool {
......@@ -41,12 +41,12 @@ mod tests {
message: String,
}
impl MaybeError for TestError {
fn from_err(err: Box<dyn Error>) -> Self {
fn from_err(err: Box<dyn Error + Send + Sync>) -> Self {
TestError {
message: err.to_string(),
}
}
fn err(&self) -> Option<Box<dyn Error>> {
fn err(&self) -> Option<Box<dyn Error + Send + Sync>> {
Some(anyhow::Error::msg(self.message.clone()).into())
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment