Unverified commit 02c822d6, authored by Tzu-Ling Kan, committed by GitHub
Browse files

feat: Using ai-perf for k8 FT tests. (#3289)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent b5782fcd
This diff is collapsed.
This diff is collapsed.
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Dict, Optional, Pattern
from tests.utils.managed_deployment import DeploymentSpec from tests.utils.managed_deployment import DeploymentSpec
# Worker name mapping for different backends
WORKER_MAP = { WORKER_MAP = {
"vllm": { "vllm": {
"decode": "VllmDecodeWorker", "decode": "VllmDecodeWorker",
...@@ -29,6 +31,51 @@ WORKER_MAP = { ...@@ -29,6 +31,51 @@ WORKER_MAP = {
}, },
} }
# Log-line regexes that signal a process is ready; used for recovery detection.

# The vLLM decode and prefill workers are recognised by the same
# initialization message, so both table entries share one compiled pattern.
_VLLM_WORKER_READY: Pattern = re.compile(
    r"VllmWorker for (?P<model_name>.*?) has been initialized"
)

# Maps a worker/process name to the regex that marks it as ready.
WORKER_READY_PATTERNS: Dict[str, Pattern] = {
    # Frontend
    "Frontend": re.compile(r"added model"),
    # vLLM workers
    "VllmDecodeWorker": _VLLM_WORKER_READY,
    "VllmPrefillWorker": _VLLM_WORKER_READY,
    # SGLang workers: any one of several initialization messages counts.
    "decode": re.compile(
        r"Model registration succeeded|Decode worker handler initialized|Worker handler initialized"
    ),
    "prefill": re.compile(
        r"Model registration succeeded|Prefill worker handler initialized|Worker handler initialized"
    ),
}
def get_all_worker_types() -> list[str]:
    """Return every known worker type name, with "Frontend" listed first.

    Names are collected from each backend's entries in ``WORKER_MAP``. A name
    that occurs more than once is reported only once, at the position of its
    first occurrence.
    """
    names = ["Frontend"]
    for role_to_worker in WORKER_MAP.values():
        names += role_to_worker.values()
    # dict keys are unique and keep insertion order, so this drops repeats
    # without reordering the remaining names.
    return list(dict.fromkeys(names))
def get_worker_ready_pattern(worker_name: str) -> Optional[Pattern]:
    """Look up the ready-detection regex registered for ``worker_name``.

    Returns the compiled pattern stored in ``WORKER_READY_PATTERNS``, or
    ``None`` when nothing is registered under that name.
    """
    try:
        return WORKER_READY_PATTERNS[worker_name]
    except KeyError:
        return None
def get_backend_workers(backend: str) -> Dict[str, str]:
    """Return the ``WORKER_MAP`` entry for ``backend``.

    An unknown backend yields a new empty dict rather than raising.
    """
    try:
        return WORKER_MAP[backend]
    except KeyError:
        return {}
@dataclass @dataclass
class Load: class Load:
...@@ -36,8 +83,7 @@ class Load: ...@@ -36,8 +83,7 @@ class Load:
requests_per_client: int = 150 requests_per_client: int = 150
input_token_length: int = 100 input_token_length: int = 100
output_token_length: int = 100 output_token_length: int = 100
max_retries: int = 1 max_retries: int = 3 # Increased for fault tolerance
max_request_rate: float = 1
sla: Optional[float] = None sla: Optional[float] = None
......
...@@ -31,7 +31,7 @@ def _clients( ...@@ -31,7 +31,7 @@ def _clients(
input_token_length, input_token_length,
output_token_length, output_token_length,
max_retries, max_retries,
max_request_rate, retry_delay=5, # Default 5 seconds between retries
): ):
procs = [] procs = []
ctx = multiprocessing.get_context("spawn") ctx = multiprocessing.get_context("spawn")
...@@ -49,7 +49,7 @@ def _clients( ...@@ -49,7 +49,7 @@ def _clients(
input_token_length, input_token_length,
output_token_length, output_token_length,
max_retries, max_retries,
max_request_rate, retry_delay,
), ),
) )
) )
...@@ -178,6 +178,5 @@ async def test_fault_scenario( ...@@ -178,6 +178,5 @@ async def test_fault_scenario(
scenario.load.input_token_length, scenario.load.input_token_length,
scenario.load.output_token_length, scenario.load.output_token_length,
scenario.load.max_retries, scenario.load.max_retries,
scenario.load.max_request_rate,
): ):
_inject_failures(scenario.failures, logger, deployment) _inject_failures(scenario.failures, logger, deployment)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment