Unverified Commit f72dc01d authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

test: add test for pre-deployment script (#2857)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent 23ede721
This diff is collapsed.
......@@ -15,11 +15,14 @@
import logging
import re
from typing import Literal, Optional, cast
from typing import Literal, Optional, Protocol
from pydantic import BaseModel
from utils.defaults import DEFAULT_MODEL_NAME, DYNAMO_RUN_DEFAULT_PORT
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
logger = logging.getLogger(__name__)
......@@ -34,27 +37,30 @@ logger.addHandler(console_handler)
class Container(BaseModel):
args: list[str] = []
args: Optional[list[str]] = None
model_config = {"extra": "allow"}
class PodSpec(BaseModel):
mainContainer: Container
mainContainer: Optional[Container] = None
model_config = {"extra": "allow"}
class ServiceResources(BaseModel):
requests: dict[str, str]
requests: Optional[dict[str, str]] = None
limits: Optional[dict[str, str]] = None
class Service(BaseModel):
replicas: int
resources: ServiceResources
extraPodSpec: PodSpec
replicas: Optional[int] = None
resources: Optional[ServiceResources] = None
extraPodSpec: Optional[PodSpec] = None
model_config = {"extra": "allow"}
class Services(BaseModel):
Frontend: Service
__root__: dict[str, Service]
model_config = {"extra": "allow"}
class Spec(BaseModel):
......@@ -68,15 +74,19 @@ class Metadata(BaseModel):
class Config(BaseModel):
metadata: Metadata
spec: Spec
model_config = {"extra": "allow"}
def break_arguments(args: list[str]) -> list[str]:
ans = []
def break_arguments(args: list[str] | None) -> list[str]:
ans: list[str] = []
if args is None:
return ans
if isinstance(args, str):
ans = re.split(r"[ =]", args)
else:
for arg in args:
ans.extend(arg.split(" "))
if arg is not None:
ans.extend(arg.split(" "))
return ans
......@@ -122,6 +132,28 @@ def find_arg_index(args: list[str]) -> int:
return idx
class ConfigModifierProtocol(Protocol):
@classmethod
def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
...
@classmethod
def set_config_tp_size(cls, config: dict, tp_size: int) -> dict:
...
@classmethod
def get_model_name(cls, config: dict) -> str:
...
@classmethod
def get_port(cls, config: dict) -> int:
...
@classmethod
def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
...
class VllmV1ConfigModifier:
@classmethod
def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
......@@ -145,9 +177,17 @@ class VllmV1ConfigModifier:
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
]
args = cfg.spec.services[
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args
]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -160,9 +200,7 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" not in args:
args = append_argument(args, "--no-enable-prefix-caching")
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode":
# delete prefill worker
......@@ -170,9 +208,17 @@ class VllmV1ConfigModifier:
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
]
args = cfg.spec.services[
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args
]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -182,9 +228,7 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" in args:
args.remove("--no-enable-prefix-caching")
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1
decode_worker_config = cfg.spec.services[
......@@ -198,27 +242,30 @@ class VllmV1ConfigModifier:
def set_config_tp_size(cls, config: dict, tp_size: int):
cfg = Config.model_validate(config)
cfg.spec.services[
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].resources.requests["gpu"] = str(tp_size)
]
# Ensure resources exists
if worker_service.resources is None:
worker_service.resources = ServiceResources()
# Ensure requests exists
if worker_service.resources.requests is None:
worker_service.resources.requests = {}
worker_service.resources.requests["gpu"] = str(tp_size)
# Update limits if they exist
if worker_service.resources.limits is not None:
worker_service.resources.limits["gpu"] = str(tp_size)
if (
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].resources.limits
is not None
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
# Explicitly cast `limits` as the typecheck cannot determine that
# limits is not None here
cast(
dict[str, str],
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].resources.limits,
)["gpu"] = str(tp_size)
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args
raise ValueError("Missing extraPodSpec or mainContainer in worker service")
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -228,9 +275,7 @@ class VllmV1ConfigModifier:
except ValueError:
args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
return cfg.model_dump()
......@@ -238,7 +283,16 @@ class VllmV1ConfigModifier:
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args
worker_service = cfg.spec.services[worker_name]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
logger.warning(
f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
for i, arg in enumerate(args):
......@@ -253,12 +307,29 @@ class VllmV1ConfigModifier:
@classmethod
def get_port(cls, config: dict) -> int:
cfg = Config.model_validate(config)
args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args
frontend_service = cfg.spec.services.get("Frontend")
if (
not frontend_service
or not frontend_service.extraPodSpec
or not frontend_service.extraPodSpec.mainContainer
):
logger.warning(
f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
return DYNAMO_RUN_DEFAULT_PORT
args = frontend_service.extraPodSpec.mainContainer.args
if not args:
logger.warning(
f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
return DYNAMO_RUN_DEFAULT_PORT
args = break_arguments(args)
try:
idx = args.index("--http-port")
return int(args[idx + 1])
except ValueError:
except (ValueError, IndexError):
logger.warning(
f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
......@@ -311,9 +382,17 @@ class SGLangConfigModifier:
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
]
args = cfg.spec.services[
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args
]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -325,9 +404,7 @@ class SGLangConfigModifier:
if "--disable-radix-cache" not in args:
args = append_argument(args, "--disable-radix-cache")
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode":
# delete prefill worker
......@@ -335,16 +412,20 @@ class SGLangConfigModifier:
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
]
args = cfg.spec.services[
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args
]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
# call `dynamo.sglang.worker` instead of `dynamo.sglang.decode_worker`
idx = args.index("dynamo.sglang.decode_worker")
args[idx] = "dynamo.sglang.worker"
# remove `--disaggregation-mode` and `--disaggregation-transfer-backend`
args = remove_valued_arguments(args, "--disaggregation-mode")
args = remove_valued_arguments(args, "--disaggregation-transfer-backend")
......@@ -353,9 +434,7 @@ class SGLangConfigModifier:
if "--disable-radix-cache" in args:
args.remove("--disable-radix-cache")
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1
decode_worker_config = config["spec"]["services"][
......@@ -369,27 +448,30 @@ class SGLangConfigModifier:
def set_config_tp_size(cls, config: dict, tp_size: int):
cfg = Config.model_validate(config)
cfg.spec.services[
worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].resources.requests["gpu"] = str(tp_size)
]
# Ensure resources exists
if worker_service.resources is None:
worker_service.resources = ServiceResources()
# Ensure requests exists
if worker_service.resources.requests is None:
worker_service.resources.requests = {}
worker_service.resources.requests["gpu"] = str(tp_size)
# Update limits if they exist
if worker_service.resources.limits is not None:
worker_service.resources.limits["gpu"] = str(tp_size)
if (
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].resources.limits
is not None
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
# Explicitly cast `limits` as the typecheck cannot determine that
# limits is not None here
cast(
dict[str, str],
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].resources.limits,
)["gpu"] = str(tp_size)
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args
raise ValueError("Missing extraPodSpec or mainContainer in worker service")
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -399,9 +481,7 @@ class SGLangConfigModifier:
except ValueError:
args = append_argument(args, ["--tp", str(tp_size)])
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
return cfg.model_dump()
......@@ -409,7 +489,16 @@ class SGLangConfigModifier:
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args
worker_service = cfg.spec.services[worker_name]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
logger.warning(
f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args)
for i, arg in enumerate(args):
......@@ -424,12 +513,29 @@ class SGLangConfigModifier:
@classmethod
def get_port(cls, config: dict) -> int:
cfg = Config.model_validate(config)
args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args
frontend_service = cfg.spec.services.get("Frontend")
if (
not frontend_service
or not frontend_service.extraPodSpec
or not frontend_service.extraPodSpec.mainContainer
):
logger.warning(
f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
return DYNAMO_RUN_DEFAULT_PORT
args = frontend_service.extraPodSpec.mainContainer.args
if not args:
logger.warning(
f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
return DYNAMO_RUN_DEFAULT_PORT
args = break_arguments(args)
try:
idx = args.index("--http-port")
return int(args[idx + 1])
except ValueError:
except (ValueError, IndexError):
logger.warning(
f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
......@@ -451,7 +557,10 @@ class SGLangConfigModifier:
return 0
CONFIG_MODIFIERS = {
CONFIG_MODIFIERS: dict[str, type[ConfigModifierProtocol]] = {
"vllm": VllmV1ConfigModifier,
"sglang": SGLangConfigModifier,
}
# Re-export WORKER_COMPONENT_NAMES for profile_sla.py
__all__ = ["CONFIG_MODIFIERS", "WORKER_COMPONENT_NAMES"]
......@@ -4,8 +4,9 @@
import logging
import numpy as np
from utils.genai_perf import benchmark_decode
from utils.plot import plot_decode_3d_surface
from benchmarks.profiler.utils.genai_perf import benchmark_decode
from benchmarks.profiler.utils.plot import plot_decode_3d_surface
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
......
......@@ -4,8 +4,9 @@
import logging
import numpy as np
from utils.genai_perf import benchmark_prefill
from utils.plot import plot_prefill_interpolation
from benchmarks.profiler.utils.genai_perf import benchmark_prefill
from benchmarks.profiler.utils.plot import plot_prefill_interpolation
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the profiler module."""
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Test suite for profile_sla dry-run functionality.
This test ensures that the profile_sla script can successfully run in dry-run mode
for both vllm and sglang backends with their respective disagg.yaml configurations.
"""
import sys
from pathlib import Path
import pytest
# Add the project root to sys.path to enable imports
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from benchmarks.profiler.profile_sla import run_profile # noqa: E402
class TestProfileSLADryRun:
"""Test class for profile_sla dry-run functionality."""
@pytest.fixture
def vllm_args(self):
"""Create arguments for vllm backend dry-run test."""
class Args:
backend = "vllm"
config = "components/backends/vllm/deploy/disagg.yaml"
output_dir = "/tmp/test_profiling_results"
namespace = "test-namespace"
min_num_gpus_per_engine = 1
max_num_gpus_per_engine = 8
skip_existing_results = False
force_rerun = False
isl = 3000
osl = 500
ttft = 50
itl = 10
max_context_length = 16384
prefill_interpolation_granularity = 16
decode_interpolation_granularity = 6
service_name = ""
dry_run = True
return Args()
@pytest.fixture
def sglang_args(self):
"""Create arguments for sglang backend dry-run test."""
class Args:
backend = "sglang"
config = "components/backends/sglang/deploy/disagg.yaml"
output_dir = "/tmp/test_profiling_results"
namespace = "test-namespace"
min_num_gpus_per_engine = 1
max_num_gpus_per_engine = 8
skip_existing_results = False
force_rerun = False
isl = 3000
osl = 500
ttft = 50
itl = 10
max_context_length = 16384
prefill_interpolation_granularity = 16
decode_interpolation_granularity = 6
service_name = ""
dry_run = True
return Args()
@pytest.mark.pre_merge
@pytest.mark.asyncio
async def test_vllm_dryrun(self, vllm_args):
"""Test that profile_sla dry-run works for vllm backend with disagg.yaml config."""
# Run the profile in dry-run mode - should complete without errors
await run_profile(vllm_args)
@pytest.mark.pre_merge
@pytest.mark.asyncio
async def test_sglang_dryrun(self, sglang_args):
"""Test that profile_sla dry-run works for sglang backend with disagg.yaml config."""
# Run the profile in dry-run mode - should complete without errors
await run_profile(sglang_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment