Unverified Commit 4b7a806c authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: add prefill workers to discovery (#3709)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent ca3daddc
...@@ -217,6 +217,20 @@ fn register_llm<'p>( ...@@ -217,6 +217,20 @@ fn register_llm<'p>(
user_data: Option<&Bound<'p, PyDict>>, user_data: Option<&Bound<'p, PyDict>>,
custom_template_path: Option<&str>, custom_template_path: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Validate Prefill model type requirements
if model_type.inner == llm_rs::model_type::ModelType::Prefill {
if !matches!(model_input, ModelInput::Tokens) {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"ModelType::Prefill requires model_input to be ModelInput::Tokens",
));
}
if migration_limit != 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"ModelType::Prefill requires migration_limit to be 0",
));
}
}
let model_input = match model_input { let model_input = match model_input {
ModelInput::Text => llm_rs::model_type::ModelInput::Text, ModelInput::Text => llm_rs::model_type::ModelInput::Text,
ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens, ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
...@@ -370,6 +384,10 @@ impl ModelType { ...@@ -370,6 +384,10 @@ impl ModelType {
const TensorBased: Self = ModelType { const TensorBased: Self = ModelType {
inner: llm_rs::model_type::ModelType::TensorBased, inner: llm_rs::model_type::ModelType::TensorBased,
}; };
#[classattr]
const Prefill: Self = ModelType {
inner: llm_rs::model_type::ModelType::Prefill,
};
fn __or__(&self, other: &Self) -> Self { fn __or__(&self, other: &Self) -> Self {
ModelType { ModelType {
......
...@@ -860,7 +860,12 @@ class ModelInput: ...@@ -860,7 +860,12 @@ class ModelInput:
... ...
class ModelType: class ModelType:
"""What type of request this model needs: Chat, Completions, Embedding or Tensor""" """What type of request this model needs: Chat, Completions, Embedding, Tensor or Prefill"""
Chat: ModelType
Completions: ModelType
Embedding: ModelType
TensorBased: ModelType
Prefill: ModelType
... ...
class RouterMode: class RouterMode:
...@@ -883,8 +888,9 @@ async def register_llm( ...@@ -883,8 +888,9 @@ async def register_llm(
model_name: Optional[str] = None, model_name: Optional[str] = None,
context_length: Optional[int] = None, context_length: Optional[int] = None,
kv_cache_block_size: Optional[int] = None, kv_cache_block_size: Optional[int] = None,
migration_limit: int = 0,
router_mode: Optional[RouterMode] = None, router_mode: Optional[RouterMode] = None,
migration_limit: int = 0,
runtime_config: Optional[ModelRuntimeConfig] = None,
user_data: Optional[Dict[str, Any]] = None, user_data: Optional[Dict[str, Any]] = None,
custom_template_path: Optional[str] = None, custom_template_path: Optional[str] = None,
) -> None: ) -> None:
......
...@@ -41,6 +41,7 @@ pub struct ModelManager { ...@@ -41,6 +41,7 @@ pub struct ModelManager {
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>, chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>, embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>, tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
prefill_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These are Mutex because we read and write rarely and equally // These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>, cards: Mutex<HashMap<String, ModelDeploymentCard>>,
...@@ -60,6 +61,7 @@ impl ModelManager { ...@@ -60,6 +61,7 @@ impl ModelManager {
chat_completion_engines: RwLock::new(ModelEngines::default()), chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()), embeddings_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()), tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()),
cards: Mutex::new(HashMap::new()), cards: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()), kv_choosers: Mutex::new(HashMap::new()),
} }
...@@ -78,6 +80,7 @@ impl ModelManager { ...@@ -78,6 +80,7 @@ impl ModelManager {
ModelType::Completions => self.completion_engines.read().checksum(model_name), ModelType::Completions => self.completion_engines.read().checksum(model_name),
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name), ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name), ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
_ => { _ => {
continue; continue;
} }
...@@ -117,6 +120,7 @@ impl ModelManager { ...@@ -117,6 +120,7 @@ impl ModelManager {
.chain(self.list_completions_models()) .chain(self.list_completions_models())
.chain(self.list_embeddings_models()) .chain(self.list_embeddings_models())
.chain(self.list_tensor_models()) .chain(self.list_tensor_models())
.chain(self.list_prefill_models())
.collect() .collect()
} }
...@@ -136,6 +140,10 @@ impl ModelManager { ...@@ -136,6 +140,10 @@ impl ModelManager {
self.tensor_engines.read().list() self.tensor_engines.read().list()
} }
pub fn list_prefill_models(&self) -> Vec<String> {
self.prefill_engines.read().list()
}
pub fn add_completions_model( pub fn add_completions_model(
&self, &self,
model: &str, model: &str,
...@@ -176,6 +184,16 @@ impl ModelManager { ...@@ -176,6 +184,16 @@ impl ModelManager {
clients.add(model, card_checksum, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_prefill_model(
&self,
model: &str,
card_checksum: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write();
clients.add(model, card_checksum, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write(); let mut clients = self.completion_engines.write();
clients.remove(model) clients.remove(model)
...@@ -196,6 +214,11 @@ impl ModelManager { ...@@ -196,6 +214,11 @@ impl ModelManager {
clients.remove(model) clients.remove(model)
} }
pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write();
clients.remove(model)
}
pub fn get_embeddings_engine( pub fn get_embeddings_engine(
&self, &self,
model: &str, model: &str,
...@@ -240,6 +263,17 @@ impl ModelManager { ...@@ -240,6 +263,17 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_prefill_engine(
&self,
model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
self.prefill_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted. /// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> { pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
......
...@@ -61,6 +61,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[ ...@@ -61,6 +61,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Completions, ModelType::Completions,
ModelType::Embedding, ModelType::Embedding,
ModelType::TensorBased, ModelType::TensorBased,
ModelType::Prefill,
]; ];
impl ModelWatcher { impl ModelWatcher {
...@@ -246,11 +247,13 @@ impl ModelWatcher { ...@@ -246,11 +247,13 @@ impl ModelWatcher {
let completions_model_remove_err = self.manager.remove_completions_model(&model_name); let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name); let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name); let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
let mut chat_model_removed = false; let mut chat_model_removed = false;
let mut completions_model_removed = false; let mut completions_model_removed = false;
let mut embeddings_model_removed = false; let mut embeddings_model_removed = false;
let mut tensor_model_removed = false; let mut tensor_model_removed = false;
let mut prefill_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() { if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true; chat_model_removed = true;
...@@ -265,26 +268,32 @@ impl ModelWatcher { ...@@ -265,26 +268,32 @@ impl ModelWatcher {
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() { if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true; tensor_model_removed = true;
} }
if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
prefill_model_removed = true;
}
if !chat_model_removed if !chat_model_removed
&& !completions_model_removed && !completions_model_removed
&& !embeddings_model_removed && !embeddings_model_removed
&& !tensor_model_removed && !tensor_model_removed
&& !prefill_model_removed
{ {
tracing::debug!( tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}", "No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
model_name, model_name,
chat_model_removed, chat_model_removed,
completions_model_removed, completions_model_removed,
embeddings_model_removed, embeddings_model_removed,
tensor_model_removed tensor_model_removed,
prefill_model_removed
); );
} else { } else {
for model_type in ALL_MODEL_TYPES { for model_type in ALL_MODEL_TYPES {
if ((chat_model_removed && *model_type == ModelType::Chat) if ((chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completions) || (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding) || (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (tensor_model_removed && *model_type == ModelType::TensorBased)) || (tensor_model_removed && *model_type == ModelType::TensorBased)
|| (prefill_model_removed && *model_type == ModelType::Prefill))
&& let Some(tx) = &self.model_update_tx && let Some(tx) = &self.model_update_tx
{ {
tx.send(ModelUpdate::Removed(card.clone())).await.ok(); tx.send(ModelUpdate::Removed(card.clone())).await.ok();
...@@ -484,11 +493,27 @@ impl ModelWatcher { ...@@ -484,11 +493,27 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_tensor_model(card.name(), checksum, engine)?; .add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_type.supports_prefill() {
// Case 6: Prefill
// Guardrail: Verify model_input is Tokens
if card.model_input != ModelInput::Tokens {
anyhow::bail!(
"Prefill models must use ModelInput::Tokens, got {}",
card.model_input.as_str()
);
}
// This is effectively a guardrail + passthrough for now
// TODO: Build proper prefill pipeline with KV router (track_active_blocks=false)
tracing::info!(
model_name = card.name(),
"Prefill model registered (passthrough, not yet functional)"
);
} else { } else {
// Reject unsupported combinations // Reject unsupported combinations
anyhow::bail!( anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \ "Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased", Tokens+(Chat|Completions|Prefill), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
card.model_type, card.model_type,
card.model_input.as_str() card.model_input.as_str()
); );
......
...@@ -36,6 +36,7 @@ bitflags! { ...@@ -36,6 +36,7 @@ bitflags! {
const Completions = 1 << 1; const Completions = 1 << 1;
const Embedding = 1 << 2; const Embedding = 1 << 2;
const TensorBased = 1 << 3; const TensorBased = 1 << 3;
const Prefill = 1 << 4;
} }
} }
...@@ -56,6 +57,9 @@ impl ModelType { ...@@ -56,6 +57,9 @@ impl ModelType {
pub fn supports_tensor(&self) -> bool { pub fn supports_tensor(&self) -> bool {
self.contains(ModelType::TensorBased) self.contains(ModelType::TensorBased)
} }
pub fn supports_prefill(&self) -> bool {
self.contains(ModelType::Prefill)
}
pub fn as_vec(&self) -> Vec<&'static str> { pub fn as_vec(&self) -> Vec<&'static str> {
let mut result = Vec::new(); let mut result = Vec::new();
...@@ -71,6 +75,9 @@ impl ModelType { ...@@ -71,6 +75,9 @@ impl ModelType {
if self.supports_tensor() { if self.supports_tensor() {
result.push("tensor"); result.push("tensor");
} }
if self.supports_prefill() {
result.push("prefill");
}
result result
} }
...@@ -90,6 +97,9 @@ impl ModelType { ...@@ -90,6 +97,9 @@ impl ModelType {
if self.supports_tensor() { if self.supports_tensor() {
result.push(ModelType::TensorBased); result.push(ModelType::TensorBased);
} }
if self.supports_prefill() {
result.push(ModelType::Prefill);
}
result result
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment