Unverified commit f72dc01d, authored by Hongkuan Zhou and committed by GitHub
Browse files

test: add test for pre-deployment script (#2857)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent 23ede721
...@@ -21,19 +21,22 @@ import os ...@@ -21,19 +21,22 @@ import os
import numpy as np import numpy as np
import yaml import yaml
from utils.config import CONFIG_MODIFIERS, WORKER_COMPONENT_NAMES
from utils.defaults import DECODE_NUM_REQUESTS_RANGE from benchmarks.profiler.utils.config import CONFIG_MODIFIERS, WORKER_COMPONENT_NAMES
from utils.genai_perf import benchmark_decode, benchmark_prefill from benchmarks.profiler.utils.defaults import DECODE_NUM_REQUESTS_RANGE
from utils.plot import plot_decode_performance, plot_prefill_performance from benchmarks.profiler.utils.genai_perf import benchmark_decode, benchmark_prefill
from utils.profile_cache import ( from benchmarks.profiler.utils.plot import (
plot_decode_performance,
plot_prefill_performance,
)
from benchmarks.profiler.utils.profile_cache import (
check_decode_results_exist, check_decode_results_exist,
check_prefill_results_exist, check_prefill_results_exist,
load_existing_decode_results, load_existing_decode_results,
load_existing_prefill_results, load_existing_prefill_results,
) )
from utils.profile_decode import profile_decode from benchmarks.profiler.utils.profile_decode import profile_decode
from utils.profile_prefill import profile_prefill from benchmarks.profiler.utils.profile_prefill import profile_prefill
from deploy.utils.dynamo_deployment import ( from deploy.utils.dynamo_deployment import (
DynamoDeploymentClient, DynamoDeploymentClient,
cleanup_remaining_deployments, cleanup_remaining_deployments,
...@@ -57,18 +60,6 @@ async def run_profile(args): ...@@ -57,18 +60,6 @@ async def run_profile(args):
try: try:
config_modifier = CONFIG_MODIFIERS[args.backend] config_modifier = CONFIG_MODIFIERS[args.backend]
if args.example_dir is None:
logger.info(
"Example directory not provided, inferring from config file location..."
)
try:
args.example_dir = os.path.dirname(os.path.dirname(args.config))
except Exception:
logger.error(
"Failed to infer example directory, please provide explicitly using --example-dir <path-to-example-dir>"
)
exit(1)
with open(args.config, "r") as f: with open(args.config, "r") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
...@@ -134,6 +125,9 @@ async def run_profile(args): ...@@ -134,6 +125,9 @@ async def run_profile(args):
with open(prefill_config_fn, "w") as f: with open(prefill_config_fn, "w") as f:
yaml.dump(prefill_config, f) yaml.dump(prefill_config, f)
if args.dry_run:
logger.info("Skipping deployment creation in dry run mode")
else:
client = DynamoDeploymentClient( client = DynamoDeploymentClient(
namespace=args.namespace, namespace=args.namespace,
base_log_dir=work_dir, base_log_dir=work_dir,
...@@ -246,6 +240,9 @@ async def run_profile(args): ...@@ -246,6 +240,9 @@ async def run_profile(args):
with open(decode_config_fn, "w") as f: with open(decode_config_fn, "w") as f:
yaml.dump(decode_config, f) yaml.dump(decode_config, f)
if args.dry_run:
logger.info("Skipping deployment creation in dry run mode")
else:
client = DynamoDeploymentClient( client = DynamoDeploymentClient(
namespace=args.namespace, namespace=args.namespace,
base_log_dir=work_dir, base_log_dir=work_dir,
...@@ -318,11 +315,13 @@ async def run_profile(args): ...@@ -318,11 +315,13 @@ async def run_profile(args):
if decode_results: if decode_results:
plot_decode_performance(decode_results, args.itl, args.output_dir) plot_decode_performance(decode_results, args.itl, args.output_dir)
if args.dry_run:
logger.info("Skipping recommendations in dry run mode")
else:
logger.info("Analyzing results and generate recommendations...") logger.info("Analyzing results and generate recommendations...")
# Safety guards: no results → exit early with a clear message # Safety guards: no results → exit early with a clear message
if not (prefill_tp_size and prefill_ttft and prefill_thpt_per_gpu): if not (prefill_tp_size and prefill_ttft and prefill_thpt_per_gpu):
logger.error("No prefill results produced; skipping recommendations.") logger.error("No prefill results produced; skipping recommendations.")
return
# select best tp size for prefill # select best tp size for prefill
if min(prefill_ttft) > args.ttft: if min(prefill_ttft) > args.ttft:
...@@ -370,7 +369,9 @@ async def run_profile(args): ...@@ -370,7 +369,9 @@ async def run_profile(args):
) )
selected_decode_idx = int(np.argmin(np.array(decode_itl))) selected_decode_idx = int(np.argmin(np.array(decode_itl)))
else: else:
valid_indices = [i for i, itl in enumerate(decode_itl) if itl <= args.itl] valid_indices = [
i for i, itl in enumerate(decode_itl) if itl <= args.itl
]
# Among valid TP sizes, select the one with highest throughput per GPU # Among valid TP sizes, select the one with highest throughput per GPU
valid_thpts = [decode_thpt_per_gpu[i] for i in valid_indices] valid_thpts = [decode_thpt_per_gpu[i] for i in valid_indices]
max_thpt_idx = valid_indices[int(np.argmax(valid_thpts))] max_thpt_idx = valid_indices[int(np.argmax(valid_thpts))]
...@@ -390,6 +391,13 @@ async def run_profile(args): ...@@ -390,6 +391,13 @@ async def run_profile(args):
f"Suggested planner upper/lower bound for decode kv cache utilization: {min(1, selected_decode_kv_cache_utilization + 0.2):.2f}/{max(0.1, selected_decode_kv_cache_utilization - 0.2):.2f}" f"Suggested planner upper/lower bound for decode kv cache utilization: {min(1, selected_decode_kv_cache_utilization + 0.2):.2f}/{max(0.1, selected_decode_kv_cache_utilization - 0.2):.2f}"
) )
if args.dry_run:
# use min value for prefill and decode TP sizes
prefill_tp_size = [args.min_num_gpus_per_engine]
decode_tp_size = [args.min_num_gpus_per_engine]
selected_prefill_idx = 0
selected_decode_idx = 0
# interpolate ISL - TTFT with best prefill TP # interpolate ISL - TTFT with best prefill TP
best_prefill_tp = prefill_tp_size[selected_prefill_idx] best_prefill_tp = prefill_tp_size[selected_prefill_idx]
logger.info( logger.info(
...@@ -408,6 +416,9 @@ async def run_profile(args): ...@@ -408,6 +416,9 @@ async def run_profile(args):
with open(prefill_config_fn, "w") as f: with open(prefill_config_fn, "w") as f:
yaml.dump(prefill_config, f) yaml.dump(prefill_config, f)
if args.dry_run:
logger.info("Skipping deployment creation in dry run mode")
else:
client = DynamoDeploymentClient( client = DynamoDeploymentClient(
namespace=args.namespace, namespace=args.namespace,
base_log_dir=work_dir, base_log_dir=work_dir,
...@@ -468,6 +479,9 @@ async def run_profile(args): ...@@ -468,6 +479,9 @@ async def run_profile(args):
with open(decode_config_fn, "w") as f: with open(decode_config_fn, "w") as f:
yaml.dump(decode_config, f) yaml.dump(decode_config, f)
if args.dry_run:
logger.info("Skipping deployment creation in dry run mode")
else:
client = DynamoDeploymentClient( client = DynamoDeploymentClient(
namespace=args.namespace, namespace=args.namespace,
base_log_dir=work_dir, base_log_dir=work_dir,
...@@ -543,12 +557,6 @@ if __name__ == "__main__": ...@@ -543,12 +557,6 @@ if __name__ == "__main__":
required=True, required=True,
help="Path to the DynamoGraphDeployment config file", help="Path to the DynamoGraphDeployment config file",
) )
parser.add_argument(
"--example-dir",
type=str,
default=None,
help="path to the example directory, if not provided, will try to infer from config file location",
)
parser.add_argument( parser.add_argument(
"--output-dir", "--output-dir",
type=str, type=str,
...@@ -614,6 +622,11 @@ if __name__ == "__main__": ...@@ -614,6 +622,11 @@ if __name__ == "__main__":
default="", default="",
help="Service name for port forwarding (default: {deployment_name}-frontend)", help="Service name for port forwarding (default: {deployment_name}-frontend)",
) )
parser.add_argument(
"--dry-run",
action="store_true",
help="Dry run the profile job",
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(run_profile(args)) asyncio.run(run_profile(args))
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
import logging import logging
import re import re
from typing import Literal, Optional, cast from typing import Literal, Optional, Protocol
from pydantic import BaseModel from pydantic import BaseModel
from utils.defaults import DEFAULT_MODEL_NAME, DYNAMO_RUN_DEFAULT_PORT
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -34,27 +37,30 @@ logger.addHandler(console_handler) ...@@ -34,27 +37,30 @@ logger.addHandler(console_handler)
class Container(BaseModel):
    """Main-container section of a service's extra pod spec."""

    # Keys not declared below are accepted instead of failing validation.
    model_config = {"extra": "allow"}

    # Container arguments; None when the config does not provide any.
    args: list[str] | None = None
class PodSpec(BaseModel):
    """Extra pod spec of a service; only the main container is modelled."""

    # Keys not declared below are accepted instead of failing validation.
    model_config = {"extra": "allow"}

    # None when the service config carries no mainContainer entry.
    mainContainer: Container | None = None
class ServiceResources(BaseModel):
    """Resource requests and limits declared for a service."""

    # Both mappings are optional; callers must handle None before indexing.
    requests: dict[str, str] | None = None
    limits: dict[str, str] | None = None
class Service(BaseModel):
    """One service entry of the deployment config; every field is optional."""

    # Keys not declared below are accepted instead of failing validation.
    model_config = {"extra": "allow"}

    replicas: int | None = None
    resources: ServiceResources | None = None
    extraPodSpec: PodSpec | None = None
class Services(BaseModel):
    """Service collection: ``Frontend`` is required, other keys are allowed."""

    # Additional service names are accepted as extra keys.
    model_config = {"extra": "allow"}

    Frontend: Service
class Spec(BaseModel): class Spec(BaseModel):
...@@ -68,14 +74,18 @@ class Metadata(BaseModel): ...@@ -68,14 +74,18 @@ class Metadata(BaseModel):
class Config(BaseModel):
    """Top-level deployment config: ``metadata`` and ``spec`` are required."""

    # Keys not declared below are accepted instead of failing validation.
    model_config = {"extra": "allow"}

    metadata: Metadata
    spec: Spec
def break_arguments(args: list[str]) -> list[str]: def break_arguments(args: list[str] | None) -> list[str]:
ans = [] ans: list[str] = []
if args is None:
return ans
if isinstance(args, str): if isinstance(args, str):
ans = re.split(r"[ =]", args) ans = re.split(r"[ =]", args)
else: else:
for arg in args: for arg in args:
if arg is not None:
ans.extend(arg.split(" ")) ans.extend(arg.split(" "))
return ans return ans
...@@ -122,6 +132,28 @@ def find_arg_index(args: list[str]) -> int: ...@@ -122,6 +132,28 @@ def find_arg_index(args: list[str]) -> int:
return idx return idx
class ConfigModifierProtocol(Protocol):
    """Structural interface expected of a per-backend config modifier class.

    Every operation is a classmethod taking plain values (a config dict or a
    log file path); no instance state is involved.
    """

    @classmethod
    def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
        """Return a config dict prepared for the given ``target`` worker."""

    @classmethod
    def set_config_tp_size(cls, config: dict, tp_size: int) -> dict:
        """Return a config dict with ``tp_size`` applied."""

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the model name taken from ``config``."""

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the port taken from ``config``."""

    @classmethod
    def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
        """Return an integer KV cache size for the log at ``dynamo_log_fn``.

        NOTE(review): presumably parsed out of the log file contents; confirm
        against the concrete implementations.
        """
class VllmV1ConfigModifier: class VllmV1ConfigModifier:
@classmethod @classmethod
def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict: def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
...@@ -145,9 +177,17 @@ class VllmV1ConfigModifier: ...@@ -145,9 +177,17 @@ class VllmV1ConfigModifier:
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
] ]
args = cfg.spec.services[ worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args ]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -160,9 +200,7 @@ class VllmV1ConfigModifier: ...@@ -160,9 +200,7 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" not in args: if "--no-enable-prefix-caching" not in args:
args = append_argument(args, "--no-enable-prefix-caching") args = append_argument(args, "--no-enable-prefix-caching")
cfg.spec.services[ worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode": elif target == "decode":
# delete prefill worker # delete prefill worker
...@@ -170,9 +208,17 @@ class VllmV1ConfigModifier: ...@@ -170,9 +208,17 @@ class VllmV1ConfigModifier:
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
] ]
args = cfg.spec.services[ worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args ]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -182,9 +228,7 @@ class VllmV1ConfigModifier: ...@@ -182,9 +228,7 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" in args: if "--no-enable-prefix-caching" in args:
args.remove("--no-enable-prefix-caching") args.remove("--no-enable-prefix-caching")
cfg.spec.services[ worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1 # set num workers to 1
decode_worker_config = cfg.spec.services[ decode_worker_config = cfg.spec.services[
...@@ -198,27 +242,30 @@ class VllmV1ConfigModifier: ...@@ -198,27 +242,30 @@ class VllmV1ConfigModifier:
def set_config_tp_size(cls, config: dict, tp_size: int): def set_config_tp_size(cls, config: dict, tp_size: int):
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
cfg.spec.services[ worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].resources.requests["gpu"] = str(tp_size) ]
# Ensure resources exists
if worker_service.resources is None:
worker_service.resources = ServiceResources()
# Ensure requests exists
if worker_service.resources.requests is None:
worker_service.resources.requests = {}
worker_service.resources.requests["gpu"] = str(tp_size)
# Update limits if they exist
if worker_service.resources.limits is not None:
worker_service.resources.limits["gpu"] = str(tp_size)
if ( if (
cfg.spec.services[ not worker_service.extraPodSpec
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name or not worker_service.extraPodSpec.mainContainer
].resources.limits
is not None
): ):
# Explicitly cast `limits` as the typecheck cannot determine that raise ValueError("Missing extraPodSpec or mainContainer in worker service")
# limits is not None here args = worker_service.extraPodSpec.mainContainer.args
cast(
dict[str, str],
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].resources.limits,
)["gpu"] = str(tp_size)
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -228,9 +275,7 @@ class VllmV1ConfigModifier: ...@@ -228,9 +275,7 @@ class VllmV1ConfigModifier:
except ValueError: except ValueError:
args = append_argument(args, ["--tensor-parallel-size", str(tp_size)]) args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])
cfg.spec.services[ worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
return cfg.model_dump() return cfg.model_dump()
...@@ -238,7 +283,16 @@ class VllmV1ConfigModifier: ...@@ -238,7 +283,16 @@ class VllmV1ConfigModifier:
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args worker_service = cfg.spec.services[worker_name]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
logger.warning(
f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
for i, arg in enumerate(args): for i, arg in enumerate(args):
...@@ -253,12 +307,29 @@ class VllmV1ConfigModifier: ...@@ -253,12 +307,29 @@ class VllmV1ConfigModifier:
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args frontend_service = cfg.spec.services.get("Frontend")
if (
not frontend_service
or not frontend_service.extraPodSpec
or not frontend_service.extraPodSpec.mainContainer
):
logger.warning(
f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
return DYNAMO_RUN_DEFAULT_PORT
args = frontend_service.extraPodSpec.mainContainer.args
if not args:
logger.warning(
f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
return DYNAMO_RUN_DEFAULT_PORT
args = break_arguments(args) args = break_arguments(args)
try: try:
idx = args.index("--http-port") idx = args.index("--http-port")
return int(args[idx + 1]) return int(args[idx + 1])
except ValueError: except (ValueError, IndexError):
logger.warning( logger.warning(
f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}" f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
) )
...@@ -311,9 +382,17 @@ class SGLangConfigModifier: ...@@ -311,9 +382,17 @@ class SGLangConfigModifier:
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
] ]
args = cfg.spec.services[ worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args ]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -325,9 +404,7 @@ class SGLangConfigModifier: ...@@ -325,9 +404,7 @@ class SGLangConfigModifier:
if "--disable-radix-cache" not in args: if "--disable-radix-cache" not in args:
args = append_argument(args, "--disable-radix-cache") args = append_argument(args, "--disable-radix-cache")
cfg.spec.services[ worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode": elif target == "decode":
# delete prefill worker # delete prefill worker
...@@ -335,16 +412,20 @@ class SGLangConfigModifier: ...@@ -335,16 +412,20 @@ class SGLangConfigModifier:
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
] ]
args = cfg.spec.services[ worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args ]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
raise ValueError(
"Missing extraPodSpec or mainContainer in worker service"
)
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
# call `dynamo.sglang.worker` instead of `dynamo.sglang.decode_worker`
idx = args.index("dynamo.sglang.decode_worker")
args[idx] = "dynamo.sglang.worker"
# remove `--disaggregation-mode` and `--disaggregation-transfer-backend` # remove `--disaggregation-mode` and `--disaggregation-transfer-backend`
args = remove_valued_arguments(args, "--disaggregation-mode") args = remove_valued_arguments(args, "--disaggregation-mode")
args = remove_valued_arguments(args, "--disaggregation-transfer-backend") args = remove_valued_arguments(args, "--disaggregation-transfer-backend")
...@@ -353,9 +434,7 @@ class SGLangConfigModifier: ...@@ -353,9 +434,7 @@ class SGLangConfigModifier:
if "--disable-radix-cache" in args: if "--disable-radix-cache" in args:
args.remove("--disable-radix-cache") args.remove("--disable-radix-cache")
cfg.spec.services[ worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1 # set num workers to 1
decode_worker_config = config["spec"]["services"][ decode_worker_config = config["spec"]["services"][
...@@ -369,27 +448,30 @@ class SGLangConfigModifier: ...@@ -369,27 +448,30 @@ class SGLangConfigModifier:
def set_config_tp_size(cls, config: dict, tp_size: int): def set_config_tp_size(cls, config: dict, tp_size: int):
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
cfg.spec.services[ worker_service = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].resources.requests["gpu"] = str(tp_size) ]
# Ensure resources exists
if worker_service.resources is None:
worker_service.resources = ServiceResources()
# Ensure requests exists
if worker_service.resources.requests is None:
worker_service.resources.requests = {}
worker_service.resources.requests["gpu"] = str(tp_size)
# Update limits if they exist
if worker_service.resources.limits is not None:
worker_service.resources.limits["gpu"] = str(tp_size)
if ( if (
cfg.spec.services[ not worker_service.extraPodSpec
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name or not worker_service.extraPodSpec.mainContainer
].resources.limits
is not None
): ):
# Explicitly cast `limits` as the typecheck cannot determine that raise ValueError("Missing extraPodSpec or mainContainer in worker service")
# limits is not None here args = worker_service.extraPodSpec.mainContainer.args
cast(
dict[str, str],
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].resources.limits,
)["gpu"] = str(tp_size)
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -399,9 +481,7 @@ class SGLangConfigModifier: ...@@ -399,9 +481,7 @@ class SGLangConfigModifier:
except ValueError: except ValueError:
args = append_argument(args, ["--tp", str(tp_size)]) args = append_argument(args, ["--tp", str(tp_size)])
cfg.spec.services[ worker_service.extraPodSpec.mainContainer.args = join_arguments(args)
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
].extraPodSpec.mainContainer.args = join_arguments(args)
return cfg.model_dump() return cfg.model_dump()
...@@ -409,7 +489,16 @@ class SGLangConfigModifier: ...@@ -409,7 +489,16 @@ class SGLangConfigModifier:
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args worker_service = cfg.spec.services[worker_name]
if (
not worker_service.extraPodSpec
or not worker_service.extraPodSpec.mainContainer
):
logger.warning(
f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
args = worker_service.extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
for i, arg in enumerate(args): for i, arg in enumerate(args):
...@@ -424,12 +513,29 @@ class SGLangConfigModifier: ...@@ -424,12 +513,29 @@ class SGLangConfigModifier:
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args frontend_service = cfg.spec.services.get("Frontend")
if (
not frontend_service
or not frontend_service.extraPodSpec
or not frontend_service.extraPodSpec.mainContainer
):
logger.warning(
f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
return DYNAMO_RUN_DEFAULT_PORT
args = frontend_service.extraPodSpec.mainContainer.args
if not args:
logger.warning(
f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
)
return DYNAMO_RUN_DEFAULT_PORT
args = break_arguments(args) args = break_arguments(args)
try: try:
idx = args.index("--http-port") idx = args.index("--http-port")
return int(args[idx + 1]) return int(args[idx + 1])
except ValueError: except (ValueError, IndexError):
logger.warning( logger.warning(
f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}" f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
) )
...@@ -451,7 +557,10 @@ class SGLangConfigModifier: ...@@ -451,7 +557,10 @@ class SGLangConfigModifier:
return 0 return 0
CONFIG_MODIFIERS = { CONFIG_MODIFIERS: dict[str, type[ConfigModifierProtocol]] = {
"vllm": VllmV1ConfigModifier, "vllm": VllmV1ConfigModifier,
"sglang": SGLangConfigModifier, "sglang": SGLangConfigModifier,
} }
# Re-export WORKER_COMPONENT_NAMES for profile_sla.py
__all__ = ["CONFIG_MODIFIERS", "WORKER_COMPONENT_NAMES"]
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
import logging import logging
import numpy as np import numpy as np
from utils.genai_perf import benchmark_decode
from utils.plot import plot_decode_3d_surface from benchmarks.profiler.utils.genai_perf import benchmark_decode
from benchmarks.profiler.utils.plot import plot_decode_3d_surface
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
import logging import logging
import numpy as np import numpy as np
from utils.genai_perf import benchmark_prefill
from utils.plot import plot_prefill_interpolation from benchmarks.profiler.utils.genai_perf import benchmark_prefill
from benchmarks.profiler.utils.plot import plot_prefill_interpolation
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the profiler module."""
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Test suite for profile_sla dry-run functionality.
This test ensures that the profile_sla script can successfully run in dry-run mode
for both vllm and sglang backends with their respective disagg.yaml configurations.
"""
import sys
from pathlib import Path
import pytest
# Add the project root to sys.path to enable imports
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from benchmarks.profiler.profile_sla import run_profile # noqa: E402
class TestProfileSLADryRun:
    """Dry-run smoke tests for ``run_profile`` on the vllm and sglang backends."""

    @staticmethod
    def _make_args(backend: str, output_dir: str):
        """Build the argument namespace passed to ``run_profile`` for a dry run.

        Args:
            backend: Backend name; also selects the ``disagg.yaml`` to load.
            output_dir: Directory the profiler writes its results into.

        Returns:
            A namespace carrying the same attributes for every backend, with
            ``dry_run`` set to ``True``.
        """
        from types import SimpleNamespace

        relative_config = (
            Path("components") / "backends" / backend / "deploy" / "disagg.yaml"
        )
        # Prefer the path under the project root so the test does not depend
        # on the working directory; fall back to the relative path otherwise.
        rooted_config = project_root / relative_config
        config_path = rooted_config if rooted_config.exists() else relative_config

        return SimpleNamespace(
            backend=backend,
            config=str(config_path),
            output_dir=output_dir,
            namespace="test-namespace",
            min_num_gpus_per_engine=1,
            max_num_gpus_per_engine=8,
            skip_existing_results=False,
            force_rerun=False,
            isl=3000,
            osl=500,
            ttft=50,
            itl=10,
            max_context_length=16384,
            prefill_interpolation_granularity=16,
            decode_interpolation_granularity=6,
            service_name="",
            dry_run=True,
        )

    @pytest.fixture
    def vllm_args(self, tmp_path):
        """Create arguments for vllm backend dry-run test."""
        # A per-test temporary directory keeps the two tests from sharing output.
        return self._make_args("vllm", str(tmp_path))

    @pytest.fixture
    def sglang_args(self, tmp_path):
        """Create arguments for sglang backend dry-run test."""
        return self._make_args("sglang", str(tmp_path))

    @pytest.mark.pre_merge
    @pytest.mark.asyncio
    async def test_vllm_dryrun(self, vllm_args):
        """Test that profile_sla dry-run works for vllm backend with disagg.yaml config."""
        # Run the profile in dry-run mode - should complete without errors
        await run_profile(vllm_args)

    @pytest.mark.pre_merge
    @pytest.mark.asyncio
    async def test_sglang_dryrun(self, sglang_args):
        """Test that profile_sla dry-run works for sglang backend with disagg.yaml config."""
        # Run the profile in dry-run mode - should complete without errors
        await run_profile(sglang_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment