Unverified Commit 1f07dab7 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Add migration to LLM requests (#1930)

parent 5f179186
...@@ -162,6 +162,11 @@ pub struct Flags { ...@@ -162,6 +162,11 @@ pub struct Flags {
#[arg(long)] #[arg(long)]
pub request_template: Option<PathBuf>, pub request_template: Option<PathBuf>,
/// How many times a request can be migrated to another worker if the HTTP server lost
/// connection to the current worker.
#[arg(long, value_parser = clap::value_parser!(u32).range(0..1024))]
pub migration_limit: Option<u32>,
/// Everything after a `--`. /// Everything after a `--`.
/// These are the command line arguments to the python engine when using `pystr` or `pytok`. /// These are the command line arguments to the python engine when using `pystr` or `pytok`.
#[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)] #[arg(index = 2, last = true, hide = true, allow_hyphen_values = true)]
...@@ -180,6 +185,9 @@ impl Flags { ...@@ -180,6 +185,9 @@ impl Flags {
if self.kv_cache_block_size.is_some() { if self.kv_cache_block_size.is_some() {
anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress"); anyhow::bail!("'--kv-cache-block-size' flag should only be used on the worker node, not on the ingress");
} }
if self.migration_limit.is_some() {
anyhow::bail!("'--migration-limit' flag should only be used on the worker node, not on the ingress");
}
} }
Output::EchoFull => {} Output::EchoFull => {}
Output::EchoCore => { Output::EchoCore => {
......
...@@ -45,7 +45,8 @@ pub async fn run( ...@@ -45,7 +45,8 @@ pub async fn run(
.context_length(flags.context_length) .context_length(flags.context_length)
.http_port(Some(flags.http_port)) .http_port(Some(flags.http_port))
.router_config(Some(flags.router_config())) .router_config(Some(flags.router_config()))
.request_template(flags.request_template.clone()); .request_template(flags.request_template.clone())
.migration_limit(flags.migration_limit);
// If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint. // If `in=dyn` we want the trtllm/sglang/vllm subprocess to listen on that endpoint.
// If not, then the endpoint isn't exposed so we let LocalModel invent one. // If not, then the endpoint isn't exposed so we let LocalModel invent one.
......
...@@ -48,6 +48,8 @@ pub async fn start( ...@@ -48,6 +48,8 @@ pub async fn start(
card.kv_cache_block_size.to_string(), card.kv_cache_block_size.to_string(),
"--context-length".to_string(), "--context-length".to_string(),
card.context_length.to_string(), card.context_length.to_string(),
"--migration-limit".to_string(),
card.migration_limit.to_string(),
]; ];
// TRTLLM only // TRTLLM only
// The worker node will only publish events and metrics if the router mode is KV // The worker node will only publish events and metrics if the router mode is KV
......
...@@ -42,6 +42,7 @@ class Config: ...@@ -42,6 +42,7 @@ class Config:
nnodes: int nnodes: int
node_rank: int node_rank: int
dist_init_addr: str dist_init_addr: str
migration_limit: int
extra_engine_args: str extra_engine_args: str
...@@ -202,7 +203,13 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -202,7 +203,13 @@ async def init(runtime: DistributedRuntime, config: Config):
model_type = ( model_type = (
ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
) )
await register_llm(model_type, endpoint, config.model_path, config.model_name) await register_llm(
model_type,
endpoint,
config.model_path,
config.model_name,
migration_limit=config.migration_limit,
)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked # after the lease is revoked
...@@ -268,6 +275,12 @@ def cmd_line_args(): ...@@ -268,6 +275,12 @@ def cmd_line_args():
default="", default="",
help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0", help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
) )
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument( parser.add_argument(
"--extra-engine-args", "--extra-engine-args",
type=str, type=str,
...@@ -304,6 +317,7 @@ def cmd_line_args(): ...@@ -304,6 +317,7 @@ def cmd_line_args():
config.nnodes = args.nnodes config.nnodes = args.nnodes
config.node_rank = args.node_rank config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr config.dist_init_addr = args.dist_init_addr
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
return config return config
......
...@@ -122,6 +122,7 @@ class Config: ...@@ -122,6 +122,7 @@ class Config:
model_name: Optional[str] = None model_name: Optional[str] = None
tensor_parallel_size: int tensor_parallel_size: int
kv_block_size: int kv_block_size: int
migration_limit: int
extra_engine_args: str extra_engine_args: str
publish_events_and_metrics: bool publish_events_and_metrics: bool
disaggregation_mode: str disaggregation_mode: str
...@@ -136,6 +137,7 @@ class Config: ...@@ -136,6 +137,7 @@ class Config:
f"model_name={self.model_name}, " f"model_name={self.model_name}, "
f"tensor_parallel_size={self.tensor_parallel_size}, " f"tensor_parallel_size={self.tensor_parallel_size}, "
f"kv_block_size={self.kv_block_size}, " f"kv_block_size={self.kv_block_size}, "
f"migration_limit={self.migration_limit}, "
f"extra_engine_args={self.extra_engine_args}, " f"extra_engine_args={self.extra_engine_args}, "
f"publish_events_and_metrics={self.publish_events_and_metrics}, " f"publish_events_and_metrics={self.publish_events_and_metrics}, "
f"disaggregation_mode={self.disaggregation_mode}, " f"disaggregation_mode={self.disaggregation_mode}, "
...@@ -404,6 +406,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -404,6 +406,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.model_path, config.model_path,
config.model_name, config.model_name,
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
) )
# publisher will be set later if publishing is enabled. # publisher will be set later if publishing is enabled.
...@@ -476,6 +479,12 @@ def cmd_line_args(): ...@@ -476,6 +479,12 @@ def cmd_line_args():
default=None, default=None,
help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.", help="This argument is not used by TRTLLM. Please provide max_input_len, max_seq_len and max_output_len in yaml file and point --extra-engine-args to the yaml file.",
) )
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument( parser.add_argument(
"--extra-engine-args", "--extra-engine-args",
type=str, type=str,
...@@ -557,6 +566,7 @@ def cmd_line_args(): ...@@ -557,6 +566,7 @@ def cmd_line_args():
config.endpoint = parsed_endpoint_name config.endpoint = parsed_endpoint_name
config.tensor_parallel_size = args.tensor_parallel_size config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size config.kv_block_size = args.kv_block_size
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
config.publish_events_and_metrics = args.publish_events_and_metrics config.publish_events_and_metrics = args.publish_events_and_metrics
config.disaggregation_mode = disaggregation_mode config.disaggregation_mode = disaggregation_mode
......
...@@ -56,6 +56,7 @@ class Config: ...@@ -56,6 +56,7 @@ class Config:
tensor_parallel_size: int tensor_parallel_size: int
kv_block_size: int kv_block_size: int
context_length: int context_length: int
migration_limit: int
extra_engine_args: str extra_engine_args: str
...@@ -233,6 +234,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -233,6 +234,7 @@ async def init(runtime: DistributedRuntime, config: Config):
"max_model_len", None "max_model_len", None
), # if None, takes length from tokenizer ), # if None, takes length from tokenizer
kv_cache_block_size=arg_map["block_size"], kv_cache_block_size=arg_map["block_size"],
migration_limit=config.migration_limit,
) )
handler = RequestHandler(component, engine_client, default_sampling_params) handler = RequestHandler(component, engine_client, default_sampling_params)
handler.setup_kv_metrics() handler.setup_kv_metrics()
...@@ -276,6 +278,12 @@ def cmd_line_args(): ...@@ -276,6 +278,12 @@ def cmd_line_args():
default=None, default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.", help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
) )
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument( parser.add_argument(
"--extra-engine-args", "--extra-engine-args",
type=str, type=str,
...@@ -308,6 +316,7 @@ def cmd_line_args(): ...@@ -308,6 +316,7 @@ def cmd_line_args():
config.tensor_parallel_size = args.tensor_parallel_size config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size config.kv_block_size = args.kv_block_size
config.context_length = args.context_length config.context_length = args.context_length
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
return config return config
......
...@@ -65,6 +65,7 @@ class Config: ...@@ -65,6 +65,7 @@ class Config:
tensor_parallel_size: int tensor_parallel_size: int
kv_block_size: int kv_block_size: int
context_length: int context_length: int
migration_limit: int
extra_engine_args: str extra_engine_args: str
...@@ -218,6 +219,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -218,6 +219,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config.model_path, config.model_path,
config.model_name, config.model_name,
kv_cache_block_size=config.kv_block_size, kv_cache_block_size=config.kv_block_size,
migration_limit=config.migration_limit,
) )
arg_map = { arg_map = {
...@@ -333,6 +335,12 @@ def cmd_line_args(): ...@@ -333,6 +335,12 @@ def cmd_line_args():
default=None, default=None,
help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.", help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
) )
parser.add_argument(
"--migration-limit",
type=int,
default=0,
help="Maximum number of times a request may be migrated to a different engine worker. The number may be overridden by the engine.",
)
parser.add_argument( parser.add_argument(
"--extra-engine-args", "--extra-engine-args",
type=str, type=str,
...@@ -365,6 +373,7 @@ def cmd_line_args(): ...@@ -365,6 +373,7 @@ def cmd_line_args():
config.tensor_parallel_size = args.tensor_parallel_size config.tensor_parallel_size = args.tensor_parallel_size
config.kv_block_size = args.kv_block_size config.kv_block_size = args.kv_block_size
config.context_length = args.context_length config.context_length = args.context_length
config.migration_limit = args.migration_limit
config.extra_engine_args = args.extra_engine_args config.extra_engine_args = args.extra_engine_args
return config return config
......
...@@ -131,7 +131,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) ...@@ -131,7 +131,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
} }
#[pyfunction] #[pyfunction]
#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None))] #[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn register_llm<'p>( fn register_llm<'p>(
py: Python<'p>, py: Python<'p>,
...@@ -142,6 +142,7 @@ fn register_llm<'p>( ...@@ -142,6 +142,7 @@ fn register_llm<'p>(
context_length: Option<u32>, context_length: Option<u32>,
kv_cache_block_size: Option<u32>, kv_cache_block_size: Option<u32>,
router_mode: Option<RouterMode>, router_mode: Option<RouterMode>,
migration_limit: u32,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let model_type_obj = match model_type { let model_type_obj = match model_type {
ModelType::Chat => llm_rs::model_type::ModelType::Chat, ModelType::Chat => llm_rs::model_type::ModelType::Chat,
...@@ -162,7 +163,8 @@ fn register_llm<'p>( ...@@ -162,7 +163,8 @@ fn register_llm<'p>(
.model_name(model_name) .model_name(model_name)
.context_length(context_length) .context_length(context_length)
.kv_cache_block_size(kv_cache_block_size) .kv_cache_block_size(kv_cache_block_size)
.router_config(Some(router_config)); .router_config(Some(router_config))
.migration_limit(Some(migration_limit));
// Download from HF, load the ModelDeploymentCard // Download from HF, load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?; let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us // Advertise ourself on etcd so ingress can find us
......
...@@ -19,6 +19,7 @@ use dynamo_runtime::{ ...@@ -19,6 +19,7 @@ use dynamo_runtime::{
use crate::{ use crate::{
backend::Backend, backend::Backend,
kv_router::{KvPushRouter, KvRouterConfig}, kv_router::{KvPushRouter, KvRouterConfig},
migration::Migration,
model_type::ModelType, model_type::ModelType,
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest}, preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, PreprocessedRequest},
protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput}, protocols::common::llm_backend::{EmbeddingsEngineOutput, LLMEngineOutput},
...@@ -197,12 +198,14 @@ impl ModelWatcher { ...@@ -197,12 +198,14 @@ impl ModelWatcher {
// function. Needs checking carefully, possibly we need to store it in state. // function. Needs checking carefully, possibly we need to store it in state.
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?); let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
// Chat Completions
let frontend = SegmentSource::< let frontend = SegmentSource::<
SingleIn<NvCreateChatCompletionRequest>, SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
>::new(); >::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator(); let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router = let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client( PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
client.clone(), client.clone(),
...@@ -231,19 +234,23 @@ impl ModelWatcher { ...@@ -231,19 +234,23 @@ impl ModelWatcher {
let chat_engine = frontend let chat_engine = frontend
.link(preprocessor.forward_edge())? .link(preprocessor.forward_edge())?
.link(backend.forward_edge())? .link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(service_backend)? .link(service_backend)?
.link(migration.backward_edge())?
.link(backend.backward_edge())? .link(backend.backward_edge())?
.link(preprocessor.backward_edge())? .link(preprocessor.backward_edge())?
.link(frontend)?; .link(frontend)?;
self.manager self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?; .add_chat_completions_model(&model_entry.name, chat_engine)?;
// Completions
let frontend = SegmentSource::< let frontend = SegmentSource::<
SingleIn<NvCreateCompletionRequest>, SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>, ManyOut<Annotated<NvCreateCompletionResponse>>,
>::new(); >::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator(); let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(card.clone()).await?.into_operator();
let migration = Migration::from_mdc(card.clone()).await?.into_operator();
let router = let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client( PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client(
client, client,
...@@ -272,7 +279,9 @@ impl ModelWatcher { ...@@ -272,7 +279,9 @@ impl ModelWatcher {
let completions_engine = frontend let completions_engine = frontend
.link(preprocessor.forward_edge())? .link(preprocessor.forward_edge())?
.link(backend.forward_edge())? .link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(service_backend)? .link(service_backend)?
.link(migration.backward_edge())?
.link(backend.backward_edge())? .link(backend.backward_edge())?
.link(preprocessor.backward_edge())? .link(preprocessor.backward_edge())?
.link(frontend)?; .link(frontend)?;
......
...@@ -22,6 +22,7 @@ pub mod hub; ...@@ -22,6 +22,7 @@ pub mod hub;
// pub mod key_value_store; // pub mod key_value_store;
pub mod kv_router; pub mod kv_router;
pub mod local_model; pub mod local_model;
pub mod migration;
pub mod mocker; pub mod mocker;
pub mod model_card; pub mod model_card;
pub mod model_type; pub mod model_type;
......
...@@ -46,6 +46,7 @@ pub struct LocalModelBuilder { ...@@ -46,6 +46,7 @@ pub struct LocalModelBuilder {
router_config: Option<RouterConfig>, router_config: Option<RouterConfig>,
kv_cache_block_size: u32, kv_cache_block_size: u32,
http_port: u16, http_port: u16,
migration_limit: u32,
} }
impl Default for LocalModelBuilder { impl Default for LocalModelBuilder {
...@@ -60,6 +61,7 @@ impl Default for LocalModelBuilder { ...@@ -60,6 +61,7 @@ impl Default for LocalModelBuilder {
context_length: Default::default(), context_length: Default::default(),
template_file: Default::default(), template_file: Default::default(),
router_config: Default::default(), router_config: Default::default(),
migration_limit: Default::default(),
} }
} }
} }
...@@ -112,6 +114,11 @@ impl LocalModelBuilder { ...@@ -112,6 +114,11 @@ impl LocalModelBuilder {
self self
} }
pub fn migration_limit(&mut self, migration_limit: Option<u32>) -> &mut Self {
self.migration_limit = migration_limit.unwrap_or(0);
self
}
/// Make an LLM ready for use: /// Make an LLM ready for use:
/// - Download it from Hugging Face (and NGC in future) if necessary /// - Download it from Hugging Face (and NGC in future) if necessary
/// - Resolve the path /// - Resolve the path
...@@ -137,10 +144,12 @@ impl LocalModelBuilder { ...@@ -137,10 +144,12 @@ impl LocalModelBuilder {
// echo_full engine doesn't need a path. It's an edge case, move it out of the way. // echo_full engine doesn't need a path. It's an edge case, move it out of the way.
if self.model_path.is_none() { if self.model_path.is_none() {
let mut card = ModelDeploymentCard::with_name_only(
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
);
card.migration_limit = self.migration_limit;
return Ok(LocalModel { return Ok(LocalModel {
card: ModelDeploymentCard::with_name_only( card,
self.model_name.as_deref().unwrap_or(DEFAULT_NAME),
),
full_path: PathBuf::new(), full_path: PathBuf::new(),
endpoint_id, endpoint_id,
template, template,
...@@ -194,6 +203,8 @@ impl LocalModelBuilder { ...@@ -194,6 +203,8 @@ impl LocalModelBuilder {
card.context_length = context_length; card.context_length = context_length;
} }
card.migration_limit = self.migration_limit;
Ok(LocalModel { Ok(LocalModel {
card, card,
full_path, full_path,
......
This diff is collapsed.
...@@ -92,6 +92,7 @@ impl ModelDeploymentCard { ...@@ -92,6 +92,7 @@ impl ModelDeploymentCard {
last_published: None, last_published: None,
context_length, context_length,
kv_cache_block_size: 0, kv_cache_block_size: 0,
migration_limit: 0,
}) })
} }
...@@ -131,6 +132,7 @@ impl ModelDeploymentCard { ...@@ -131,6 +132,7 @@ impl ModelDeploymentCard {
last_published: None, last_published: None,
context_length, context_length,
kv_cache_block_size: 0, // set later kv_cache_block_size: 0, // set later
migration_limit: 0,
}) })
} }
} }
......
...@@ -127,6 +127,10 @@ pub struct ModelDeploymentCard { ...@@ -127,6 +127,10 @@ pub struct ModelDeploymentCard {
/// Size of a KV cache block - vllm only currently /// Size of a KV cache block - vllm only currently
/// Passed to the engine and the KV router. /// Passed to the engine and the KV router.
pub kv_cache_block_size: u32, pub kv_cache_block_size: u32,
/// How many times a request can be migrated to another worker if the HTTP server lost
/// connection to the current worker.
pub migration_limit: u32,
} }
impl ModelDeploymentCard { impl ModelDeploymentCard {
......
...@@ -136,11 +136,11 @@ impl LLMEngineOutput { ...@@ -136,11 +136,11 @@ impl LLMEngineOutput {
} }
impl MaybeError for LLMEngineOutput { impl MaybeError for LLMEngineOutput {
fn from_err(err: Box<dyn std::error::Error>) -> Self { fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
LLMEngineOutput::error(format!("{:?}", err)) LLMEngineOutput::error(format!("{:?}", err))
} }
fn err(&self) -> Option<Box<dyn std::error::Error>> { fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
if let Some(FinishReason::Error(err_msg)) = &self.finish_reason { if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
Some(anyhow::Error::msg(err_msg.clone()).into()) Some(anyhow::Error::msg(err_msg.clone()).into())
} else { } else {
......
...@@ -6,14 +6,8 @@ use crate::pipeline::{ ...@@ -6,14 +6,8 @@ use crate::pipeline::{
SingleIn, SingleIn,
}; };
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use rand::Rng;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::RwLock; use std::sync::Arc;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
};
use std::time::Instant;
use tokio::net::unix::pipe::Receiver; use tokio::net::unix::pipe::Receiver;
use crate::{ use crate::{
...@@ -48,10 +42,8 @@ pub struct Client { ...@@ -48,10 +42,8 @@ pub struct Client {
pub endpoint: Endpoint, pub endpoint: Endpoint,
// These are the remotes I know about from watching etcd // These are the remotes I know about from watching etcd
pub instance_source: Arc<InstanceSource>, pub instance_source: Arc<InstanceSource>,
// These are the instances that are reported as down from sending rpc // These are the instance source ids less those reported as down from sending rpc
instance_inhibited: Arc<Mutex<HashMap<i64, Instant>>>, instance_avail: Arc<ArcSwap<Vec<i64>>>,
// The current active IDs
instance_cache: Arc<ArcSwap<Vec<i64>>>,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
...@@ -60,16 +52,13 @@ pub enum InstanceSource { ...@@ -60,16 +52,13 @@ pub enum InstanceSource {
Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>), Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
} }
// TODO: Avoid returning a full clone of `Vec<Instance>` everytime from Client
// See instances() and instances_avail() methods
impl Client { impl Client {
// Client will only talk to a single static endpoint // Client will only talk to a single static endpoint
pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> { pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
Ok(Client { Ok(Client {
endpoint, endpoint,
instance_source: Arc::new(InstanceSource::Static), instance_source: Arc::new(InstanceSource::Static),
instance_inhibited: Arc::new(Mutex::new(HashMap::new())), instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
}) })
} }
...@@ -85,26 +74,12 @@ impl Client { ...@@ -85,26 +74,12 @@ impl Client {
let instance_source = let instance_source =
Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?; Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;
let cancel_token = endpoint.drt().primary_token();
let client = Client { let client = Client {
endpoint, endpoint,
instance_source, instance_source,
instance_inhibited: Arc::new(Mutex::new(HashMap::new())), instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
instance_cache: Arc::new(ArcSwap::from(Arc::new(vec![]))),
}; };
client.monitor_instance_source();
let instance_source_c = client.instance_source.clone();
let instance_inhibited_c = Arc::clone(&client.instance_inhibited);
let instance_cache_c = Arc::clone(&client.instance_cache);
tokio::task::spawn(async move {
while !cancel_token.is_cancelled() {
refresh_instances(&instance_source_c, &instance_inhibited_c, &instance_cache_c);
tokio::select! {
_ = cancel_token.cancelled() => {}
_ = tokio::time::sleep(INSTANCE_REFRESH_PERIOD) => {}
}
}
});
Ok(client) Ok(client)
} }
...@@ -119,13 +94,20 @@ impl Client { ...@@ -119,13 +94,20 @@ impl Client {
/// Instances available from watching etcd /// Instances available from watching etcd
pub fn instances(&self) -> Vec<Instance> { pub fn instances(&self) -> Vec<Instance> {
instances_inner(self.instance_source.as_ref()) match self.instance_source.as_ref() {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
} }
pub fn instance_ids(&self) -> Vec<i64> { pub fn instance_ids(&self) -> Vec<i64> {
self.instances().into_iter().map(|ep| ep.id()).collect() self.instances().into_iter().map(|ep| ep.id()).collect()
} }
pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
self.instance_avail.load()
}
/// Wait for at least one Instance to be available for this Endpoint /// Wait for at least one Instance to be available for this Endpoint
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> { pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
let mut instances: Vec<Instance> = vec![]; let mut instances: Vec<Instance> = vec![];
...@@ -143,24 +125,51 @@ impl Client { ...@@ -143,24 +125,51 @@ impl Client {
Ok(instances) Ok(instances)
} }
/// Instances available from watching etcd minus those reported as down /// Is this component know at startup and not discovered via etcd?
pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> { pub fn is_static(&self) -> bool {
self.instance_cache.load() matches!(self.instance_source.as_ref(), InstanceSource::Static)
} }
/// Mark an instance as down/unavailable /// Mark an instance as down/unavailable
pub fn report_instance_down(&self, instance_id: i64) { pub fn report_instance_down(&self, instance_id: i64) {
self.instance_inhibited let filtered = self
.lock() .instance_ids_avail()
.unwrap() .iter()
.insert(instance_id, Instant::now()); .filter_map(|&id| if id == instance_id { None } else { Some(id) })
.collect::<Vec<_>>();
self.instance_avail.store(Arc::new(filtered));
tracing::debug!("inhibiting instance {instance_id}"); tracing::debug!("inhibiting instance {instance_id}");
} }
/// Is this component know at startup and not discovered via etcd? /// Monitor the ETCD instance source and update instance_avail.
pub fn is_static(&self) -> bool { fn monitor_instance_source(&self) {
matches!(self.instance_source.as_ref(), InstanceSource::Static) let cancel_token = self.endpoint.drt().primary_token();
let client = self.clone();
tokio::task::spawn(async move {
let mut rx = match client.instance_source.as_ref() {
InstanceSource::Static => {
tracing::error!("Static instance source is not watchable");
return;
}
InstanceSource::Dynamic(rx) => rx.clone(),
};
while !cancel_token.is_cancelled() {
let instance_ids: Vec<i64> = rx
.borrow_and_update()
.iter()
.map(|instance| instance.id())
.collect();
client.instance_avail.store(Arc::new(instance_ids));
tracing::debug!("instance source updated");
if let Err(err) = rx.changed().await {
tracing::error!("The Sender is dropped: {}", err);
cancel_token.cancel();
}
}
});
} }
async fn get_or_create_dynamic_instance_source( async fn get_or_create_dynamic_instance_source(
...@@ -253,49 +262,3 @@ impl Client { ...@@ -253,49 +262,3 @@ impl Client {
Ok(instance_source) Ok(instance_source)
} }
} }
/// Update the instance id cache
fn refresh_instances(
instance_source: &InstanceSource,
instance_inhibited: &Arc<Mutex<HashMap<i64, Instant>>>,
instance_cache: &Arc<ArcSwap<Vec<i64>>>,
) {
const ETCD_LEASE_TTL: u64 = 10; // seconds
// TODO: Can we get the remaining TTL from the lease for the instance?
let now = Instant::now();
let instances = instances_inner(instance_source);
let mut inhibited = instance_inhibited.lock().unwrap();
// 1. Remove inhibited instances that are no longer in `self.instances()`
// 2. Remove inhibited instances that have expired
// 3. Only return instances that are not inhibited after removals
let mut new_inhibited = HashMap::<i64, Instant>::new();
let filtered: Vec<i64> = instances
.into_iter()
.filter_map(|instance| {
let id = instance.id();
if let Some(&timestamp) = inhibited.get(&id) {
if now.duration_since(timestamp).as_secs() > ETCD_LEASE_TTL {
Some(id)
} else {
new_inhibited.insert(id, timestamp);
None
}
} else {
Some(id)
}
})
.collect();
*inhibited = new_inhibited;
instance_cache.store(Arc::new(filtered));
}
fn instances_inner(instance_source: &InstanceSource) -> Vec<Instance> {
match instance_source {
InstanceSource::Static => vec![],
InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
}
...@@ -178,20 +178,14 @@ where ...@@ -178,20 +178,14 @@ where
Ok(stream) => { Ok(stream) => {
let engine_ctx = stream.context(); let engine_ctx = stream.context();
let client = self.client.clone(); let client = self.client.clone();
let stream = stream.then(move |res| { let stream = stream.map(move |res| {
let mut report_instance_down: Option<(Client, i64)> = None;
if let Some(err) = res.err() { if let Some(err) = res.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed"; const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG { if format!("{:?}", err) == STREAM_ERR_MSG {
report_instance_down = Some((client.clone(), instance_id));
}
}
async move {
if let Some((client, instance_id)) = report_instance_down {
client.report_instance_down(instance_id); client.report_instance_down(instance_id);
} }
res
} }
res
}); });
Ok(ResponseStream::new(Box::pin(stream), engine_ctx)) Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
} }
......
...@@ -151,11 +151,11 @@ impl<R> MaybeError for Annotated<R> ...@@ -151,11 +151,11 @@ impl<R> MaybeError for Annotated<R>
where where
R: for<'de> Deserialize<'de> + Serialize, R: for<'de> Deserialize<'de> + Serialize,
{ {
fn from_err(err: Box<dyn std::error::Error>) -> Self { fn from_err(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
Annotated::from_error(format!("{:?}", err)) Annotated::from_error(format!("{:?}", err))
} }
fn err(&self) -> Option<Box<dyn std::error::Error>> { fn err(&self) -> Option<Box<dyn std::error::Error + Send + Sync>> {
if self.is_error() { if self.is_error() {
if let Some(comment) = &self.comment { if let Some(comment) = &self.comment {
if !comment.is_empty() { if !comment.is_empty() {
......
...@@ -17,10 +17,10 @@ use std::error::Error; ...@@ -17,10 +17,10 @@ use std::error::Error;
pub trait MaybeError { pub trait MaybeError {
/// Construct an instance from an error. /// Construct an instance from an error.
fn from_err(err: Box<dyn Error>) -> Self; fn from_err(err: Box<dyn Error + Send + Sync>) -> Self;
/// Construct into an error instance. /// Construct into an error instance.
fn err(&self) -> Option<Box<dyn Error>>; fn err(&self) -> Option<Box<dyn Error + Send + Sync>>;
/// Check if the current instance represents a success. /// Check if the current instance represents a success.
fn is_ok(&self) -> bool { fn is_ok(&self) -> bool {
...@@ -41,12 +41,12 @@ mod tests { ...@@ -41,12 +41,12 @@ mod tests {
message: String, message: String,
} }
impl MaybeError for TestError { impl MaybeError for TestError {
fn from_err(err: Box<dyn Error>) -> Self { fn from_err(err: Box<dyn Error + Send + Sync>) -> Self {
TestError { TestError {
message: err.to_string(), message: err.to_string(),
} }
} }
fn err(&self) -> Option<Box<dyn Error>> { fn err(&self) -> Option<Box<dyn Error + Send + Sync>> {
Some(anyhow::Error::msg(self.message.clone()).into()) Some(anyhow::Error::msg(self.message.clone()).into())
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment