Unverified Commit 4b7a806c authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: add prefill workers to discovery (#3709)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent ca3daddc
......@@ -217,6 +217,20 @@ fn register_llm<'p>(
user_data: Option<&Bound<'p, PyDict>>,
custom_template_path: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> {
// Validate Prefill model type requirements
if model_type.inner == llm_rs::model_type::ModelType::Prefill {
if !matches!(model_input, ModelInput::Tokens) {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"ModelType::Prefill requires model_input to be ModelInput::Tokens",
));
}
if migration_limit != 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"ModelType::Prefill requires migration_limit to be 0",
));
}
}
let model_input = match model_input {
ModelInput::Text => llm_rs::model_type::ModelInput::Text,
ModelInput::Tokens => llm_rs::model_type::ModelInput::Tokens,
......@@ -370,6 +384,10 @@ impl ModelType {
const TensorBased: Self = ModelType {
inner: llm_rs::model_type::ModelType::TensorBased,
};
/// Class attribute wrapping `llm_rs::model_type::ModelType::Prefill`.
#[classattr]
const Prefill: Self = ModelType {
inner: llm_rs::model_type::ModelType::Prefill,
};
fn __or__(&self, other: &Self) -> Self {
ModelType {
......
......@@ -860,7 +860,12 @@ class ModelInput:
...
class ModelType:
"""What type of request this model needs: Chat, Completions, Embedding or Tensor"""
"""What type of request this model needs: Chat, Completions, Embedding, Tensor or Prefill"""
Chat: ModelType
Completions: ModelType
Embedding: ModelType
TensorBased: ModelType
Prefill: ModelType
...
class RouterMode:
......@@ -883,8 +888,9 @@ async def register_llm(
model_name: Optional[str] = None,
context_length: Optional[int] = None,
kv_cache_block_size: Optional[int] = None,
migration_limit: int = 0,
router_mode: Optional[RouterMode] = None,
migration_limit: int = 0,
runtime_config: Optional[ModelRuntimeConfig] = None,
user_data: Optional[Dict[str, Any]] = None,
custom_template_path: Optional[str] = None,
) -> None:
......
......@@ -41,6 +41,7 @@ pub struct ModelManager {
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
prefill_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>,
......@@ -60,6 +61,7 @@ impl ModelManager {
chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()),
cards: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()),
}
......@@ -78,6 +80,7 @@ impl ModelManager {
ModelType::Completions => self.completion_engines.read().checksum(model_name),
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
_ => {
continue;
}
......@@ -117,6 +120,7 @@ impl ModelManager {
.chain(self.list_completions_models())
.chain(self.list_embeddings_models())
.chain(self.list_tensor_models())
.chain(self.list_prefill_models())
.collect()
}
......@@ -136,6 +140,10 @@ impl ModelManager {
self.tensor_engines.read().list()
}
/// Lists the models currently held in `prefill_engines`.
///
/// Takes the read lock on `prefill_engines` for the duration of the call.
pub fn list_prefill_models(&self) -> Vec<String> {
    let engines = self.prefill_engines.read();
    engines.list()
}
pub fn add_completions_model(
&self,
model: &str,
......@@ -176,6 +184,16 @@ impl ModelManager {
clients.add(model, card_checksum, engine)
}
/// Registers `engine` for `model` in `prefill_engines`, together with `card_checksum`.
///
/// The write lock on `prefill_engines` is held only while the entry is added.
///
/// # Errors
///
/// Returns whatever error `ModelEngines::add` reports for this entry.
pub fn add_prefill_model(
    &self,
    model: &str,
    card_checksum: &str,
    engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
    self.prefill_engines
        .write()
        .add(model, card_checksum, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write();
clients.remove(model)
......@@ -196,6 +214,11 @@ impl ModelManager {
clients.remove(model)
}
/// Removes the entry for `model` from `prefill_engines`.
///
/// The write lock on `prefill_engines` is held only while the entry is removed.
///
/// # Errors
///
/// Returns whatever error `ModelEngines::remove` reports for this model.
pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
    self.prefill_engines.write().remove(model)
}
pub fn get_embeddings_engine(
&self,
model: &str,
......@@ -240,6 +263,17 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Looks up the prefill engine registered for `model` and returns a clone of it.
///
/// Takes the read lock on `prefill_engines` for the duration of the lookup.
///
/// # Errors
///
/// Returns [`ModelManagerError::ModelNotFound`] carrying the model name when
/// `prefill_engines` has no entry for `model`.
pub fn get_prefill_engine(
    &self,
    model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
    self.prefill_engines
        .read()
        .get(model)
        .cloned()
        // Build the error lazily so a successful lookup does not allocate the
        // model-name `String` that only the failure path needs.
        .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
......
......@@ -61,6 +61,7 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Completions,
ModelType::Embedding,
ModelType::TensorBased,
ModelType::Prefill,
];
impl ModelWatcher {
......@@ -246,11 +247,13 @@ impl ModelWatcher {
let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
let mut chat_model_removed = false;
let mut completions_model_removed = false;
let mut embeddings_model_removed = false;
let mut tensor_model_removed = false;
let mut prefill_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true;
......@@ -265,26 +268,32 @@ impl ModelWatcher {
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true;
}
if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
prefill_model_removed = true;
}
if !chat_model_removed
&& !completions_model_removed
&& !embeddings_model_removed
&& !tensor_model_removed
&& !prefill_model_removed
{
tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}",
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
model_name,
chat_model_removed,
completions_model_removed,
embeddings_model_removed,
tensor_model_removed
tensor_model_removed,
prefill_model_removed
);
} else {
for model_type in ALL_MODEL_TYPES {
if ((chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (tensor_model_removed && *model_type == ModelType::TensorBased))
|| (tensor_model_removed && *model_type == ModelType::TensorBased)
|| (prefill_model_removed && *model_type == ModelType::Prefill))
&& let Some(tx) = &self.model_update_tx
{
tx.send(ModelUpdate::Removed(card.clone())).await.ok();
......@@ -484,11 +493,27 @@ impl ModelWatcher {
let engine = Arc::new(push_router);
self.manager
.add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_type.supports_prefill() {
// Case 6: Prefill
// Guardrail: Verify model_input is Tokens
if card.model_input != ModelInput::Tokens {
anyhow::bail!(
"Prefill models must use ModelInput::Tokens, got {}",
card.model_input.as_str()
);
}
// This is effectively a guardrail + passthrough for now
// TODO: Build proper prefill pipeline with KV router (track_active_blocks=false)
tracing::info!(
model_name = card.name(),
"Prefill model registered (passthrough, not yet functional)"
);
} else {
// Reject unsupported combinations
anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
Tokens+(Chat|Completions|Prefill), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
card.model_type,
card.model_input.as_str()
);
......
......@@ -36,6 +36,7 @@ bitflags! {
const Completions = 1 << 1;
const Embedding = 1 << 2;
const TensorBased = 1 << 3;
const Prefill = 1 << 4;
}
}
......@@ -56,6 +57,9 @@ impl ModelType {
pub fn supports_tensor(&self) -> bool {
self.contains(ModelType::TensorBased)
}
/// Reports whether the `Prefill` flag is set in this model type.
pub fn supports_prefill(&self) -> bool {
    let prefill_flag = ModelType::Prefill;
    self.contains(prefill_flag)
}
pub fn as_vec(&self) -> Vec<&'static str> {
let mut result = Vec::new();
......@@ -71,6 +75,9 @@ impl ModelType {
if self.supports_tensor() {
result.push("tensor");
}
if self.supports_prefill() {
result.push("prefill");
}
result
}
......@@ -90,6 +97,9 @@ impl ModelType {
if self.supports_tensor() {
result.push(ModelType::TensorBased);
}
if self.supports_prefill() {
result.push(ModelType::Prefill);
}
result
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment