Unverified commit feb49914, authored by Yan Ru Pei, committed by GitHub
Browse files

fix: cleaner prefill + decode discovery (#4927)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 0de75a87
......@@ -138,10 +138,21 @@ impl ModelManager {
self.cards.lock().values().cloned().collect()
}
pub fn has_model_any(&self, model: &str) -> bool {
/// Check if a decode model (chat or completions) is registered
pub fn has_decode_model(&self, model: &str) -> bool {
self.chat_completion_engines.read().contains(model)
|| self.completion_engines.read().contains(model)
|| self.prefill_engines.read().contains(model)
}
/// Report whether a prefill engine is registered for `model`.
///
/// Only the prefill engine registry is consulted.
pub fn has_prefill_model(&self, model: &str) -> bool {
    let prefill = self.prefill_engines.read();
    prefill.contains(model)
}
/// Report whether `model` is registered as either a decode or a prefill model.
///
/// Note: registration skip-checks should call `has_decode_model()` or
/// `has_prefill_model()` rather than this combined check.
pub fn has_model_any(&self, model: &str) -> bool {
    // Same short-circuit order as a plain `||`: decode first, then prefill.
    if self.has_decode_model(model) {
        return true;
    }
    self.has_prefill_model(model)
}
pub fn model_display_names(&self) -> HashSet<String> {
......
......@@ -333,21 +333,21 @@ impl ModelWatcher {
tracing::debug!(model_name = card.name(), "adding model");
self.manager.save_model_card(key, card.clone())?;
// Check if we should skip registration:
// - Skip if a model with this name already exists
// - UNLESS this is a prefill model and no prefill model exists yet for this name
let is_new_prefill = card.model_type.supports_prefill()
&& !self
.manager
.list_prefill_models()
.contains(&card.name().to_string());
// Skip duplicate registrations based on model type.
// Prefill and decode models are tracked separately, so registering one
// doesn't block the other (they can arrive in any order).
let already_registered = if card.model_type.supports_prefill() {
self.manager.has_prefill_model(card.name())
} else {
self.manager.has_decode_model(card.name())
};
if self.manager.has_model_any(card.name()) && !is_new_prefill {
if already_registered {
tracing::debug!(
model_name = card.name(),
namespace = endpoint_id.namespace,
model_type = %card.model_type,
"New endpoint for existing model, skipping"
"Model already registered, skipping"
);
return Ok(());
}
......
......@@ -44,12 +44,13 @@ def get_unique_ports(
num_ports: int = 1,
store_backend: str = "etcd",
request_plane: str = "nats",
registration_order: str = "prefill_first",
) -> list[int]:
"""Generate unique ports for parallel test execution.
Ports are unique based on:
- Test function name (each test gets a base offset)
- Parametrization value (etcd=0, file=50; nats=0, tcp=25)
- Parametrization value (etcd=0, file=50; nats=0, tcp=25; prefill_first=0, decode_first=10)
- Port index (for multi-port tests)
Args:
......@@ -57,6 +58,7 @@ def get_unique_ports(
num_ports: Number of ports needed (1 for single router, 2 for two routers)
store_backend: Storage backend parameter ("etcd" or "file")
request_plane: Request plane parameter ("nats" or "tcp")
registration_order: Registration order parameter ("prefill_first" or "decode_first")
Returns:
List of unique port numbers
......@@ -76,13 +78,14 @@ def get_unique_ports(
base_offset = test_offsets.get(test_name, 0)
# Parametrization offset (etcd=0, file=50; nats=0, tcp=25)
# Parametrization offset (etcd=0, file=50; nats=0, tcp=25; prefill_first=0, decode_first=10)
store_offset = 0 if store_backend == "etcd" else 50
plane_offset = 0 if request_plane == "nats" else 25
order_offset = 0 if registration_order == "prefill_first" else 10
# Generate ports
ports = [
BASE_PORT + base_offset + store_offset + plane_offset + i
BASE_PORT + base_offset + store_offset + plane_offset + order_offset + i
for i in range(num_ports)
]
return ports
......@@ -596,15 +599,23 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz
@pytest.mark.parallel
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
def test_router_decisions_disagg(
request, runtime_services_session, predownload_tokenizers
request, runtime_services_session, predownload_tokenizers, registration_order
):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup.
Tests that progressive requests with overlapping prefixes are routed to the
same prefill worker due to KV cache reuse.
Parameterized to test both registration orders:
- prefill_first: prefill workers register before decode workers
- decode_first: decode workers register before prefill workers
"""
logger.info("Starting disaggregated router prefix reuse test")
logger.info(
f"Starting disaggregated router prefix reuse test "
f"(registration_order={registration_order})"
)
# Generate shared namespace for prefill and decode workers
namespace_suffix = generate_random_suffix()
......@@ -617,32 +628,59 @@ def test_router_decisions_disagg(
decode_workers = None
try:
# Start prefill workers (4 instances)
logger.info("Starting 4 prefill mocker instances")
prefill_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
)
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Start decode workers (4 instances)
logger.info("Starting 4 decode mocker instances")
decode_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
)
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
if registration_order == "prefill_first":
# Start prefill workers first
logger.info("Starting 4 prefill mocker instances (first)")
prefill_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
)
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Then start decode workers
logger.info("Starting 4 decode mocker instances (second)")
decode_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
)
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
else:
# Start decode workers first
logger.info("Starting 4 decode mocker instances (first)")
decode_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
)
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Then start prefill workers
logger.info("Starting 4 prefill mocker instances (second)")
prefill_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
)
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0]
frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
# Run disagg routing test
_test_router_decisions_disagg(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment