Unverified Commit 6e241236 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

test: parallel mocker tests in CI (#4493)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent febdc998
...@@ -23,6 +23,7 @@ jobs: ...@@ -23,6 +23,7 @@ jobs:
env: env:
CONTAINER_ID: test_${{ github.run_id }}_${{ github.run_attempt }}_${{ github.job }}_dynamo CONTAINER_ID: test_${{ github.run_id }}_${{ github.run_attempt }}_${{ github.job }}_dynamo
PYTEST_XML_FILE: pytest_test_report.xml PYTEST_XML_FILE: pytest_test_report.xml
PYTEST_PARALLEL_XML_FILE: pytest_parallel.xml
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
...@@ -62,9 +63,21 @@ jobs: ...@@ -62,9 +63,21 @@ jobs:
working-directory: ./deploy working-directory: ./deploy
run: | run: |
docker compose down docker compose down
- name: Run pytest - name: Run pytest (parallel tests with xdist)
env: env:
PYTEST_MARKS: "pre_merge or mypy" PYTEST_MARKS: "pre_merge and parallel"
run: |
docker run -w /workspace \
--name ${{ env.CONTAINER_ID }}_pytest_parallel \
${{ steps.define_image_tag.outputs.image_tag }} \
bash -c "pytest --basetemp=/tmp/pytest-parallel --junitxml=${{ env.PYTEST_PARALLEL_XML_FILE }} -n 4 -m \"${{ env.PYTEST_MARKS }}\""
- name: Copy parallel test report from Container
if: always()
run: |
docker cp ${{ env.CONTAINER_ID }}_pytest_parallel:/workspace/${{ env.PYTEST_PARALLEL_XML_FILE }} . || echo "No parallel test report found"
- name: Run pytest (sequential tests)
env:
PYTEST_MARKS: "(pre_merge and not parallel) or mypy"
run: | run: |
docker run -w /workspace \ docker run -w /workspace \
--name ${{ env.CONTAINER_ID }}_pytest \ --name ${{ env.CONTAINER_ID }}_pytest \
...@@ -82,6 +95,7 @@ jobs: ...@@ -82,6 +95,7 @@ jobs:
if-no-files-found: error if-no-files-found: error
path: | path: |
${{ env.PYTEST_XML_FILE }} ${{ env.PYTEST_XML_FILE }}
${{ env.PYTEST_PARALLEL_XML_FILE }}
event_file: event_file:
name: "Event File" name: "Event File"
......
...@@ -28,6 +28,7 @@ pytest-forked ...@@ -28,6 +28,7 @@ pytest-forked
pytest-md-report pytest-md-report
pytest-mypy pytest-mypy
pytest-timeout pytest-timeout
pytest-xdist
# Triton client to Dynamo gRPC server # Triton client to Dynamo gRPC server
tritonclient[grpc] tritonclient[grpc]
# add types library stub for PyYAML # add types library stub for PyYAML
......
...@@ -175,6 +175,8 @@ filterwarnings = [ ...@@ -175,6 +175,8 @@ filterwarnings = [
"ignore:Support for class-based `config`.*:pydantic.warnings.PydanticDeprecatedSince20", "ignore:Support for class-based `config`.*:pydantic.warnings.PydanticDeprecatedSince20",
"ignore:Using extra keyword arguments on `Field`.*:pydantic.warnings.PydanticDeprecatedSince20", "ignore:Using extra keyword arguments on `Field`.*:pydantic.warnings.PydanticDeprecatedSince20",
"ignore:The `schema` method is deprecated.*:pydantic.warnings.PydanticDeprecatedSince20", "ignore:The `schema` method is deprecated.*:pydantic.warnings.PydanticDeprecatedSince20",
# pytest-benchmark automatically disables when xdist is active, ignore the warning
"ignore:.*Benchmarks are automatically disabled.*:pytest_benchmark.logger.PytestBenchmarkWarning",
] ]
...@@ -182,6 +184,7 @@ filterwarnings = [ ...@@ -182,6 +184,7 @@ filterwarnings = [
asyncio_mode = "auto" asyncio_mode = "auto"
markers = [ markers = [
"pre_merge: marks tests to run before merging", "pre_merge: marks tests to run before merging",
"parallel: marks tests that can run in parallel with pytest-xdist",
"nightly: marks tests to run nightly", "nightly: marks tests to run nightly",
"weekly: marks tests to run weekly", "weekly: marks tests to run weekly",
"gpu_1: marks tests to run on GPU", "gpu_1: marks tests to run on GPU",
......
...@@ -17,8 +17,11 @@ import logging ...@@ -17,8 +17,11 @@ import logging
import os import os
import shutil import shutil
import tempfile import tempfile
from pathlib import Path
from typing import Optional
import pytest import pytest
from filelock import FileLock
from tests.utils.constants import TEST_MODELS from tests.utils.constants import TEST_MODELS
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -229,6 +232,131 @@ class NatsServer(ManagedProcess): ...@@ -229,6 +232,131 @@ class NatsServer(ManagedProcess):
) )
class SharedManagedProcess:
"""Base class for ManagedProcess with file-based reference counting for multi-process sharing."""
def __init__(
self,
request,
tmp_path_factory,
resource_name: str,
port: int,
timeout: int = 300,
):
self.request = request
self.port = port
self.timeout = timeout
self.resource_name = resource_name
self._server: Optional[ManagedProcess] = None
self._owns_process = False
root_tmp = Path(tempfile.gettempdir()) / "pytest_ref_counting"
root_tmp.mkdir(parents=True, exist_ok=True)
self.ref_file = root_tmp / f"pytest_{resource_name}_{port}_ref_count"
self.lock_file = str(self.ref_file) + ".lock"
def _create_server(self) -> ManagedProcess:
"""Create the underlying server instance. Must be implemented by subclasses."""
raise NotImplementedError
def _read_ref_count(self) -> int:
"""Read current reference count."""
if self.ref_file.exists():
try:
return int(self.ref_file.read_text().strip())
except (ValueError, IOError):
return 0
return 0
def _write_ref_count(self, count: int):
"""Write reference count atomically."""
self.ref_file.write_text(str(count))
def _increment_ref_count(self) -> int:
"""Increment reference count and return new count."""
count = self._read_ref_count()
count += 1
self._write_ref_count(count)
return count
def _decrement_ref_count(self) -> int:
"""Decrement reference count and return new count."""
count = self._read_ref_count()
count = max(0, count - 1)
self._write_ref_count(count)
return count
def __enter__(self):
with FileLock(self.lock_file):
ref_count = self._increment_ref_count()
if ref_count == 1:
# First reference - start the process
self._server = self._create_server()
self._server.__enter__()
self._owns_process = True
logging.info(f"[{self.resource_name}] Started process (ref_count=1)")
else:
# Process already running, just track reference
self._owns_process = False
logging.info(
f"[{self.resource_name}] Reusing existing process (ref_count={ref_count})"
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
with FileLock(self.lock_file):
ref_count = self._decrement_ref_count()
if ref_count == 0 and self._owns_process:
# Last reference - stop the process
if self._server:
self._server.__exit__(exc_type, exc_val, exc_tb)
logging.info(f"[{self.resource_name}] Stopped process (ref_count=0)")
elif ref_count == 0:
# Last reference but we don't own it - shouldn't happen, but clean up ref file
if self.ref_file.exists():
self.ref_file.unlink()
logging.warning(
f"[{self.resource_name}] Ref count reached 0 but we don't own process"
)
else:
logging.info(
f"[{self.resource_name}] Released reference (ref_count={ref_count})"
)
class SharedEtcdServer(SharedManagedProcess):
"""EtcdServer with file-based reference counting for multi-process sharing."""
def __init__(self, request, tmp_path_factory, port=2379, timeout=300):
super().__init__(request, tmp_path_factory, "etcd", port, timeout)
# Create a log directory for session-scoped servers
self._log_dir = tempfile.mkdtemp(prefix=f"pytest_{self.resource_name}_logs_")
def _create_server(self) -> ManagedProcess:
"""Create EtcdServer instance."""
server = EtcdServer(self.request, port=self.port, timeout=self.timeout)
# Override log_dir since request.node.name is empty in session scope
server.log_dir = self._log_dir
return server
class SharedNatsServer(SharedManagedProcess):
"""NatsServer with file-based reference counting for multi-process sharing."""
def __init__(self, request, tmp_path_factory, port=4222, timeout=300):
super().__init__(request, tmp_path_factory, "nats", port, timeout)
# Create a log directory for session-scoped servers
self._log_dir = tempfile.mkdtemp(prefix=f"pytest_{self.resource_name}_logs_")
def _create_server(self) -> ManagedProcess:
"""Create NatsServer instance."""
server = NatsServer(self.request, port=self.port, timeout=self.timeout)
# Override log_dir since request.node.name is empty in session scope
server.log_dir = self._log_dir
return server
@pytest.fixture() @pytest.fixture()
def runtime_services(request): def runtime_services(request):
with NatsServer(request) as nats_process: with NatsServer(request) as nats_process:
...@@ -236,6 +364,20 @@ def runtime_services(request): ...@@ -236,6 +364,20 @@ def runtime_services(request):
yield nats_process, etcd_process yield nats_process, etcd_process
@pytest.fixture(scope="session")
def runtime_services_session(request, tmp_path_factory):
"""Session-scoped fixture that provides shared NATS and etcd instances for all tests.
Uses file-based reference counting to coordinate between pytest-xdist worker processes.
Only the first worker starts services, and only the last worker tears them down.
Test isolation is achieved through unique namespaces (test-namespace-{random-suffix}).
"""
with SharedNatsServer(request, tmp_path_factory) as nats:
with SharedEtcdServer(request, tmp_path_factory) as etcd:
yield nats, etcd
@pytest.fixture @pytest.fixture
def file_storage_backend(): def file_storage_backend():
"""Fixture that sets up and tears down file storage backend. """Fixture that sets up and tears down file storage backend.
......
...@@ -6,6 +6,7 @@ import json ...@@ -6,6 +6,7 @@ import json
import logging import logging
import random import random
import string import string
import time
from typing import Any, Optional from typing import Any, Optional
import aiohttp import aiohttp
...@@ -29,7 +30,12 @@ class KVRouterProcess(ManagedProcess): ...@@ -29,7 +30,12 @@ class KVRouterProcess(ManagedProcess):
"""Manages the KV router process using dynamo.frontend""" """Manages the KV router process using dynamo.frontend"""
def __init__( def __init__(
self, request, block_size: int, frontend_port: int, store_backend: str = "etcd" self,
request,
block_size: int,
frontend_port: int,
namespace: str,
store_backend: str = "etcd",
): ):
command = [ command = [
"python3", "python3",
...@@ -43,6 +49,8 @@ class KVRouterProcess(ManagedProcess): ...@@ -43,6 +49,8 @@ class KVRouterProcess(ManagedProcess):
str(frontend_port), str(frontend_port),
"--store-kv", "--store-kv",
store_backend, store_backend,
"--namespace",
namespace,
] ]
super().__init__( super().__init__(
...@@ -498,17 +506,15 @@ def _test_router_basic( ...@@ -498,17 +506,15 @@ def _test_router_basic(
frontend_port: int, frontend_port: int,
test_payload: dict, test_payload: dict,
num_requests: int, num_requests: int,
wait_for_frontend: bool = False, frontend_timeout: int = 120,
frontend_timeout: int = 180,
store_backend: str = "etcd", store_backend: str = "etcd",
): ):
"""Basic router test: start router, wait for workers (optional) and send concurrent requests via HTTP frontend. """Basic router test: start router, wait for workers and send concurrent requests via HTTP frontend.
Assumes engine_workers are already initialized. This function manages router lifecycle. Assumes engine_workers are already initialized. This function manages router lifecycle.
This is a shared test implementation for both mocker and vLLM workers. This is a shared test implementation for both mocker and vLLM workers.
The key difference is that vLLM workers need time to load models and register, Always waits for workers to be properly registered before sending requests to avoid flakiness.
so they require wait_for_frontend=True.
Args: Args:
engine_workers: Backend workers (mocker/vllm) already initialized with __enter__() engine_workers: Backend workers (mocker/vllm) already initialized with __enter__()
...@@ -517,8 +523,7 @@ def _test_router_basic( ...@@ -517,8 +523,7 @@ def _test_router_basic(
frontend_port: Port to start the frontend HTTP server on frontend_port: Port to start the frontend HTTP server on
test_payload: Test payload to send to /v1/chat/completions test_payload: Test payload to send to /v1/chat/completions
num_requests: Number of concurrent requests to send num_requests: Number of concurrent requests to send
wait_for_frontend: If True, poll /v1/models and /v1/chat/completions until ready frontend_timeout: Timeout for frontend readiness check (default: 120s)
frontend_timeout: Timeout for frontend readiness check (only used if wait_for_frontend=True)
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd". store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
Raises: Raises:
...@@ -528,21 +533,22 @@ def _test_router_basic( ...@@ -528,21 +533,22 @@ def _test_router_basic(
try: try:
# Start KV router frontend # Start KV router frontend
logger.info(f"Starting KV router frontend on port {frontend_port}") logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(request, block_size, frontend_port, store_backend) kv_router = KVRouterProcess(
request, block_size, frontend_port, engine_workers.namespace, store_backend
)
kv_router.__enter__() kv_router.__enter__()
frontend_url = f"http://localhost:{frontend_port}" frontend_url = f"http://localhost:{frontend_port}"
# Wait for workers to register with frontend if needed (vLLM requires this) # Always wait for workers to register with frontend to avoid flakiness
if wait_for_frontend: logger.info("Waiting for workers to register with frontend...")
logger.info("Waiting for workers to register with frontend...") asyncio.run(
asyncio.run( wait_for_frontend_ready(
wait_for_frontend_ready( frontend_url=frontend_url,
frontend_url=frontend_url, expected_num_workers=engine_workers.num_workers,
expected_num_workers=engine_workers.num_workers, timeout=frontend_timeout,
timeout=frontend_timeout,
)
) )
)
# Send concurrent requests to the frontend # Send concurrent requests to the frontend
logger.info(f"Sending {num_requests} concurrent requests to frontend...") logger.info(f"Sending {num_requests} concurrent requests to frontend...")
...@@ -598,12 +604,36 @@ def _test_router_two_routers( ...@@ -598,12 +604,36 @@ def _test_router_two_routers(
try: try:
# Start two KV routers on different ports # Start two KV routers on different ports
for port in router_ports: for i, port in enumerate(router_ports):
logger.info(f"Starting KV router frontend on port {port}") logger.info(f"Starting KV router frontend on port {port}")
kv_router = KVRouterProcess(request, block_size, port, store_backend) kv_router = KVRouterProcess(
request, block_size, port, engine_workers.namespace, store_backend
)
kv_router.__enter__() kv_router.__enter__()
kv_routers.append(kv_router) kv_routers.append(kv_router)
# Add delay between routers for file backend to ensure first router's
# registration is visible before second router starts its cleanup
if i == 0 and store_backend == "file":
logger.info(
"Waiting 0.5s for first router to fully register (file backend)"
)
time.sleep(0.5)
# Wait for workers to be ready on both routers
logger.info("Waiting for workers to register with both routers...")
for i, port in enumerate(router_ports):
frontend_url = f"http://localhost:{port}"
logger.info(f"Waiting for router {i} on port {port} to discover workers...")
asyncio.run(
wait_for_frontend_ready(
frontend_url=frontend_url,
expected_num_workers=engine_workers.num_workers,
timeout=120,
)
)
logger.info("Both routers have discovered workers")
# Build URLs for both routers # Build URLs for both routers
router_urls = [ router_urls = [
f"http://localhost:{port}/v1/chat/completions" for port in router_ports f"http://localhost:{port}/v1/chat/completions" for port in router_ports
...@@ -874,7 +904,9 @@ def _test_router_query_instance_id( ...@@ -874,7 +904,9 @@ def _test_router_query_instance_id(
try: try:
# Start KV router (frontend) # Start KV router (frontend)
logger.info(f"Starting KV router frontend on port {frontend_port}") logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(request, block_size, frontend_port, store_backend) kv_router = KVRouterProcess(
request, block_size, frontend_port, engine_workers.namespace, store_backend
)
kv_router.__enter__() kv_router.__enter__()
url = f"http://localhost:{frontend_port}/v1/chat/completions" url = f"http://localhost:{frontend_port}/v1/chat/completions"
......
...@@ -29,14 +29,50 @@ logger = logging.getLogger(__name__) ...@@ -29,14 +29,50 @@ logger = logging.getLogger(__name__)
MODEL_NAME = ROUTER_MODEL_NAME MODEL_NAME = ROUTER_MODEL_NAME
NUM_MOCKERS = 2 NUM_MOCKERS = 2
SPEEDUP_RATIO = 10.0 SPEEDUP_RATIO = 10.0
PORTS = [ BASE_PORT = 9100 # Base port for all tests (high port to avoid conflicts)
8011,
8022,
] # Frontend ports: use PORTS[0] for single router, PORTS for multi-router
NUM_REQUESTS = 100 NUM_REQUESTS = 100
BLOCK_SIZE = 16 BLOCK_SIZE = 16
def get_unique_ports(
request, num_ports: int = 1, store_backend: str = "etcd"
) -> list[int]:
"""Generate unique ports for parallel test execution.
Ports are unique based on:
- Test function name (each test gets a base offset)
- Parametrization value (etcd=0, file=50)
- Port index (for multi-port tests)
Args:
request: Pytest request fixture
num_ports: Number of ports needed (1 for single router, 2 for two routers)
store_backend: Storage backend parameter ("etcd" or "file")
Returns:
List of unique port numbers
"""
# Get test name without parametrization suffix
test_name = request.node.name.split("[")[0]
# Base offsets per test function (ensures each test gets unique range)
test_offsets = {
"test_mocker_kv_router": 0,
"test_mocker_two_kv_router": 100,
"test_mocker_kv_router_overload_503": 200,
"test_query_instance_id_returns_worker_and_tokens": 300,
}
base_offset = test_offsets.get(test_name, 0)
# Parametrization offset (etcd=0, file=50)
param_offset = 0 if store_backend == "etcd" else 50
# Generate ports
ports = [BASE_PORT + base_offset + param_offset + i for i in range(num_ports)]
return ports
# Shared test payload for all tests # Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = { TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME, "model": MODEL_NAME,
...@@ -148,8 +184,9 @@ class MockerProcess: ...@@ -148,8 +184,9 @@ class MockerProcess:
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
def test_mocker_kv_router(request, runtime_services, predownload_tokenizers): def test_mocker_kv_router(request, runtime_services_session, predownload_tokenizers):
""" """
Test KV router with multiple mocker engine instances. Test KV router with multiple mocker engine instances.
This test doesn't require GPUs and runs quickly for pre-merge validation. This test doesn't require GPUs and runs quickly for pre-merge validation.
...@@ -170,15 +207,17 @@ def test_mocker_kv_router(request, runtime_services, predownload_tokenizers): ...@@ -170,15 +207,17 @@ def test_mocker_kv_router(request, runtime_services, predownload_tokenizers):
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__() mockers.__enter__()
# Run basic router test (starts router internally, mocker workers don't need frontend readiness check) # Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0]
# Run basic router test (starts router internally and waits for workers to be ready)
_test_router_basic( _test_router_basic(
engine_workers=mockers, engine_workers=mockers,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
request=request, request=request,
frontend_port=PORTS[0], frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
wait_for_frontend=False, # Mocker workers are fast, no need to wait
) )
finally: finally:
...@@ -187,11 +226,12 @@ def test_mocker_kv_router(request, runtime_services, predownload_tokenizers): ...@@ -187,11 +226,12 @@ def test_mocker_kv_router(request, runtime_services, predownload_tokenizers):
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
@pytest.mark.parametrize("store_backend", ["etcd", "file"]) @pytest.mark.parametrize("store_backend", ["etcd", "file"])
def test_mocker_two_kv_router( def test_mocker_two_kv_router(
request, request,
runtime_services, runtime_services_session,
predownload_tokenizers, predownload_tokenizers,
file_storage_backend, file_storage_backend,
store_backend, store_backend,
...@@ -222,12 +262,17 @@ def test_mocker_two_kv_router( ...@@ -222,12 +262,17 @@ def test_mocker_two_kv_router(
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__() mockers.__enter__()
# Get unique ports for this test (2 ports for two routers)
router_ports = get_unique_ports(
request, num_ports=2, store_backend=store_backend
)
# Run two-router test (starts KV routers internally and manages their lifecycle) # Run two-router test (starts KV routers internally and manages their lifecycle)
_test_router_two_routers( _test_router_two_routers(
engine_workers=mockers, engine_workers=mockers,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
request=request, request=request,
router_ports=PORTS, router_ports=router_ports,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
store_backend=store_backend, store_backend=store_backend,
...@@ -239,10 +284,11 @@ def test_mocker_two_kv_router( ...@@ -239,10 +284,11 @@ def test_mocker_two_kv_router(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
@pytest.mark.skip(reason="Flaky, temporarily disabled") @pytest.mark.skip(reason="Flaky, temporarily disabled")
def test_mocker_kv_router_overload_503( def test_mocker_kv_router_overload_503(
request, runtime_services, predownload_tokenizers request, runtime_services_session, predownload_tokenizers
): ):
"""Test that KV router returns 503 when mocker workers are overloaded.""" """Test that KV router returns 503 when mocker workers are overloaded."""
logger.info("Starting mocker KV router overload test for 503 status") logger.info("Starting mocker KV router overload test for 503 status")
...@@ -260,8 +306,10 @@ def test_mocker_kv_router_overload_503( ...@@ -260,8 +306,10 @@ def test_mocker_kv_router_overload_503(
logger.info(f"Mocker using endpoint: {mockers.endpoint}") logger.info(f"Mocker using endpoint: {mockers.endpoint}")
mockers.__enter__() mockers.__enter__()
# Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0]
# Run overload 503 test # Run overload 503 test
frontend_port = PORTS[0] + 10 # Use different port to avoid conflicts
_test_router_overload_503( _test_router_overload_503(
engine_workers=mockers, engine_workers=mockers,
block_size=4, # Match the mocker's block size block_size=4, # Match the mocker's block size
...@@ -277,8 +325,11 @@ def test_mocker_kv_router_overload_503( ...@@ -277,8 +325,11 @@ def test_mocker_kv_router_overload_503(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
def test_kv_push_router_bindings(request, runtime_services, predownload_tokenizers): def test_kv_push_router_bindings(
request, runtime_services_session, predownload_tokenizers
):
"""Test KvPushRouter Python bindings with mocker engines.""" """Test KvPushRouter Python bindings with mocker engines."""
logger.info("Starting KvPushRouter bindings test") logger.info("Starting KvPushRouter bindings test")
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE} mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
...@@ -313,11 +364,12 @@ def test_kv_push_router_bindings(request, runtime_services, predownload_tokenize ...@@ -313,11 +364,12 @@ def test_kv_push_router_bindings(request, runtime_services, predownload_tokenize
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
@pytest.mark.parametrize("store_backend", ["etcd", "file"]) @pytest.mark.parametrize("store_backend", ["etcd", "file"])
def test_indexers_sync( def test_indexers_sync(
request, request,
runtime_services, runtime_services_session,
predownload_tokenizers, predownload_tokenizers,
file_storage_backend, file_storage_backend,
store_backend, store_backend,
...@@ -364,9 +416,10 @@ def test_indexers_sync( ...@@ -364,9 +416,10 @@ def test_indexers_sync(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
def test_query_instance_id_returns_worker_and_tokens( def test_query_instance_id_returns_worker_and_tokens(
request, runtime_services, predownload_tokenizers request, runtime_services_session, predownload_tokenizers
): ):
"""Test query_instance_id annotation with mocker engines.""" """Test query_instance_id annotation with mocker engines."""
logger.info("Starting KV router query_instance_id annotation test") logger.info("Starting KV router query_instance_id annotation test")
...@@ -382,8 +435,10 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -382,8 +435,10 @@ def test_query_instance_id_returns_worker_and_tokens(
logger.info(f"All mockers using endpoint: {mockers.endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__() mockers.__enter__()
# Get unique port for this test
frontend_port = get_unique_ports(request, num_ports=1)[0]
# Run query_instance_id annotation test # Run query_instance_id annotation test
frontend_port = PORTS[0] + 30 # Use unique port to avoid conflicts
_test_router_query_instance_id( _test_router_query_instance_id(
engine_workers=mockers, engine_workers=mockers,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
...@@ -398,8 +453,9 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -398,8 +453,9 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME) @pytest.mark.model(MODEL_NAME)
def test_router_decisions(request, runtime_services, predownload_tokenizers): def test_router_decisions(request, runtime_services_session, predownload_tokenizers):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.""" """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes."""
# runtime_services starts etcd and nats # runtime_services starts etcd and nats
......
...@@ -302,7 +302,7 @@ def test_vllm_kv_router_basic(request, runtime_services, predownload_tokenizers) ...@@ -302,7 +302,7 @@ def test_vllm_kv_router_basic(request, runtime_services, predownload_tokenizers)
logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}")
vllm_workers.__enter__() vllm_workers.__enter__()
# Run basic router test (starts router internally, vLLM workers need frontend readiness check) # Run basic router test (starts router internally and waits for workers to be ready)
_test_router_basic( _test_router_basic(
engine_workers=vllm_workers, engine_workers=vllm_workers,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
...@@ -310,7 +310,6 @@ def test_vllm_kv_router_basic(request, runtime_services, predownload_tokenizers) ...@@ -310,7 +310,6 @@ def test_vllm_kv_router_basic(request, runtime_services, predownload_tokenizers)
frontend_port=PORTS[0], frontend_port=PORTS[0],
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
wait_for_frontend=True, # vLLM workers need time to load models
frontend_timeout=180, # 3 minutes should be plenty for TinyLlama frontend_timeout=180, # 3 minutes should be plenty for TinyLlama
store_backend="etcd", # Explicit for clarity store_backend="etcd", # Explicit for clarity
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment