"lib/bindings/vscode:/vscode.git/clone" did not exist on "7c74764b973d7de6867b87ec25b78f2b843e2733"
Unverified Commit feb49914 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: cleaner prefill + decode discovery (#4927)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 0de75a87
...@@ -138,10 +138,21 @@ impl ModelManager { ...@@ -138,10 +138,21 @@ impl ModelManager {
self.cards.lock().values().cloned().collect() self.cards.lock().values().cloned().collect()
} }
pub fn has_model_any(&self, model: &str) -> bool { /// Check if a decode model (chat or completions) is registered
pub fn has_decode_model(&self, model: &str) -> bool {
self.chat_completion_engines.read().contains(model) self.chat_completion_engines.read().contains(model)
|| self.completion_engines.read().contains(model) || self.completion_engines.read().contains(model)
|| self.prefill_engines.read().contains(model) }
/// Check if a prefill model is registered
pub fn has_prefill_model(&self, model: &str) -> bool {
self.prefill_engines.read().contains(model)
}
/// Check if any model (decode or prefill) is registered.
/// Note: For registration skip-checks, use has_decode_model() or has_prefill_model() instead.
pub fn has_model_any(&self, model: &str) -> bool {
self.has_decode_model(model) || self.has_prefill_model(model)
} }
pub fn model_display_names(&self) -> HashSet<String> { pub fn model_display_names(&self) -> HashSet<String> {
......
...@@ -333,21 +333,21 @@ impl ModelWatcher { ...@@ -333,21 +333,21 @@ impl ModelWatcher {
tracing::debug!(model_name = card.name(), "adding model"); tracing::debug!(model_name = card.name(), "adding model");
self.manager.save_model_card(key, card.clone())?; self.manager.save_model_card(key, card.clone())?;
// Check if we should skip registration: // Skip duplicate registrations based on model type.
// - Skip if a model with this name already exists // Prefill and decode models are tracked separately, so registering one
// - UNLESS this is a prefill model and no prefill model exists yet for this name // doesn't block the other (they can arrive in any order).
let is_new_prefill = card.model_type.supports_prefill() let already_registered = if card.model_type.supports_prefill() {
&& !self self.manager.has_prefill_model(card.name())
.manager } else {
.list_prefill_models() self.manager.has_decode_model(card.name())
.contains(&card.name().to_string()); };
if self.manager.has_model_any(card.name()) && !is_new_prefill { if already_registered {
tracing::debug!( tracing::debug!(
model_name = card.name(), model_name = card.name(),
namespace = endpoint_id.namespace, namespace = endpoint_id.namespace,
model_type = %card.model_type, model_type = %card.model_type,
"New endpoint for existing model, skipping" "Model already registered, skipping"
); );
return Ok(()); return Ok(());
} }
......
...@@ -44,12 +44,13 @@ def get_unique_ports( ...@@ -44,12 +44,13 @@ def get_unique_ports(
num_ports: int = 1, num_ports: int = 1,
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats", request_plane: str = "nats",
registration_order: str = "prefill_first",
) -> list[int]: ) -> list[int]:
"""Generate unique ports for parallel test execution. """Generate unique ports for parallel test execution.
Ports are unique based on: Ports are unique based on:
- Test function name (each test gets a base offset) - Test function name (each test gets a base offset)
- Parametrization value (etcd=0, file=50; nats=0, tcp=25) - Parametrization value (etcd=0, file=50; nats=0, tcp=25; prefill_first=0, decode_first=10)
- Port index (for multi-port tests) - Port index (for multi-port tests)
Args: Args:
...@@ -57,6 +58,7 @@ def get_unique_ports( ...@@ -57,6 +58,7 @@ def get_unique_ports(
num_ports: Number of ports needed (1 for single router, 2 for two routers) num_ports: Number of ports needed (1 for single router, 2 for two routers)
store_backend: Storage backend parameter ("etcd" or "file") store_backend: Storage backend parameter ("etcd" or "file")
request_plane: Request plane parameter ("nats" or "tcp") request_plane: Request plane parameter ("nats" or "tcp")
registration_order: Registration order parameter ("prefill_first" or "decode_first")
Returns: Returns:
List of unique port numbers List of unique port numbers
...@@ -76,13 +78,14 @@ def get_unique_ports( ...@@ -76,13 +78,14 @@ def get_unique_ports(
base_offset = test_offsets.get(test_name, 0) base_offset = test_offsets.get(test_name, 0)
# Parametrization offset (etcd=0, file=50; nats=0, tcp=25) # Parametrization offset (etcd=0, file=50; nats=0, tcp=25; prefill_first=0, decode_first=10)
store_offset = 0 if store_backend == "etcd" else 50 store_offset = 0 if store_backend == "etcd" else 50
plane_offset = 0 if request_plane == "nats" else 25 plane_offset = 0 if request_plane == "nats" else 25
order_offset = 0 if registration_order == "prefill_first" else 10
# Generate ports # Generate ports
ports = [ ports = [
BASE_PORT + base_offset + store_offset + plane_offset + i BASE_PORT + base_offset + store_offset + plane_offset + order_offset + i
for i in range(num_ports) for i in range(num_ports)
] ]
return ports return ports
...@@ -596,15 +599,23 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz ...@@ -596,15 +599,23 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz
@pytest.mark.parallel @pytest.mark.parallel
@pytest.mark.parametrize("registration_order", ["prefill_first", "decode_first"])
def test_router_decisions_disagg( def test_router_decisions_disagg(
request, runtime_services_session, predownload_tokenizers request, runtime_services_session, predownload_tokenizers, registration_order
): ):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup. """Validate KV cache prefix reuse in disaggregated prefill-decode setup.
Tests that progressive requests with overlapping prefixes are routed to the Tests that progressive requests with overlapping prefixes are routed to the
same prefill worker due to KV cache reuse. same prefill worker due to KV cache reuse.
Parameterized to test both registration orders:
- prefill_first: prefill workers register before decode workers
- decode_first: decode workers register before prefill workers
""" """
logger.info("Starting disaggregated router prefix reuse test") logger.info(
f"Starting disaggregated router prefix reuse test "
f"(registration_order={registration_order})"
)
# Generate shared namespace for prefill and decode workers # Generate shared namespace for prefill and decode workers
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
...@@ -617,8 +628,9 @@ def test_router_decisions_disagg( ...@@ -617,8 +628,9 @@ def test_router_decisions_disagg(
decode_workers = None decode_workers = None
try: try:
# Start prefill workers (4 instances) if registration_order == "prefill_first":
logger.info("Starting 4 prefill mocker instances") # Start prefill workers first
logger.info("Starting 4 prefill mocker instances (first)")
prefill_workers = DisaggMockerProcess( prefill_workers = DisaggMockerProcess(
request, request,
namespace=shared_namespace, namespace=shared_namespace,
...@@ -629,8 +641,20 @@ def test_router_decisions_disagg( ...@@ -629,8 +641,20 @@ def test_router_decisions_disagg(
prefill_workers.__enter__() prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}") logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Start decode workers (4 instances) # Then start decode workers
logger.info("Starting 4 decode mocker instances") logger.info("Starting 4 decode mocker instances (second)")
decode_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="decode",
mocker_args=mocker_args,
num_mockers=4,
)
decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
else:
# Start decode workers first
logger.info("Starting 4 decode mocker instances (first)")
decode_workers = DisaggMockerProcess( decode_workers = DisaggMockerProcess(
request, request,
namespace=shared_namespace, namespace=shared_namespace,
...@@ -641,8 +665,22 @@ def test_router_decisions_disagg( ...@@ -641,8 +665,22 @@ def test_router_decisions_disagg(
decode_workers.__enter__() decode_workers.__enter__()
logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}") logger.info(f"Decode workers using endpoint: {decode_workers.endpoint}")
# Then start prefill workers
logger.info("Starting 4 prefill mocker instances (second)")
prefill_workers = DisaggMockerProcess(
request,
namespace=shared_namespace,
worker_type="prefill",
mocker_args=mocker_args,
num_mockers=4,
)
prefill_workers.__enter__()
logger.info(f"Prefill workers using endpoint: {prefill_workers.endpoint}")
# Get unique port for this test # Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0] frontend_port = get_unique_ports(
request, num_ports=1, registration_order=registration_order
)[0]
# Run disagg routing test # Run disagg routing test
_test_router_decisions_disagg( _test_router_decisions_disagg(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment