Unverified Commit 882ae1b4 authored by Michael Shin's avatar Michael Shin Committed by GitHub
Browse files

chore: add pydantic validation for benchmark profile config (#2791)


Signed-off-by: default avatarMichael Shin <michaelshin@users.noreply.github.com>
parent 5bbbeae3
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
import logging import logging
import re import re
from copy import deepcopy from typing import Literal, Optional, cast
from typing import Literal
from pydantic import BaseModel
from utils.defaults import DEFAULT_MODEL_NAME, DYNAMO_RUN_DEFAULT_PORT from utils.defaults import DEFAULT_MODEL_NAME, DYNAMO_RUN_DEFAULT_PORT
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
...@@ -33,6 +33,43 @@ console_handler.setFormatter(formatter) ...@@ -33,6 +33,43 @@ console_handler.setFormatter(formatter)
logger.addHandler(console_handler) logger.addHandler(console_handler)
class Container(BaseModel):
args: list[str] = []
class PodSpec(BaseModel):
mainContainer: Container
class ServiceResources(BaseModel):
requests: dict[str, str]
limits: Optional[dict[str, str]] = None
class Service(BaseModel):
replicas: int
resources: ServiceResources
extraPodSpec: PodSpec
class Services(BaseModel):
Frontend: Service
__root__: dict[str, Service]
class Spec(BaseModel):
services: dict[str, Service]
class Metadata(BaseModel):
name: str
class Config(BaseModel):
metadata: Metadata
spec: Spec
def break_arguments(args: list[str]) -> list[str]: def break_arguments(args: list[str]) -> list[str]:
ans = [] ans = []
if isinstance(args, str): if isinstance(args, str):
...@@ -88,29 +125,29 @@ def find_arg_index(args: list[str]) -> int: ...@@ -88,29 +125,29 @@ def find_arg_index(args: list[str]) -> int:
class VllmV1ConfigModifier: class VllmV1ConfigModifier:
@classmethod @classmethod
def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict: def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
config = deepcopy(config) cfg = Config.model_validate(config)
# set metadata name # set metadata name
config["metadata"]["name"] = "vllm-agg" cfg.metadata.name = "vllm-agg"
# disable planner # disable planner
if "Planner" in config["spec"]["services"]: if "Planner" in cfg.spec.services:
del config["spec"]["services"]["Planner"] del cfg.spec.services["Planner"]
if target == "prefill": if target == "prefill":
# convert prefill worker into decode worker # convert prefill worker into decode worker
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
] = config["spec"]["services"][ ] = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
] ]
del config["spec"]["services"][ del cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
] ]
args = config["spec"]["services"][ args = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] ].extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -123,19 +160,19 @@ class VllmV1ConfigModifier: ...@@ -123,19 +160,19 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" not in args: if "--no-enable-prefix-caching" not in args:
args = append_argument(args, "--no-enable-prefix-caching") args = append_argument(args, "--no-enable-prefix-caching")
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args) ].extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode": elif target == "decode":
# delete prefill worker # delete prefill worker
del config["spec"]["services"][ del cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
] ]
args = config["spec"]["services"][ args = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] ].extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -145,38 +182,43 @@ class VllmV1ConfigModifier: ...@@ -145,38 +182,43 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" in args: if "--no-enable-prefix-caching" in args:
args.remove("--no-enable-prefix-caching") args.remove("--no-enable-prefix-caching")
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args) ].extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1 # set num workers to 1
decode_worker_config = config["spec"]["services"][ decode_worker_config = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
] ]
decode_worker_config["replicas"] = 1 decode_worker_config.replicas = 1
return config return cfg.model_dump()
@classmethod @classmethod
def set_config_tp_size(cls, config: dict, tp_size: int): def set_config_tp_size(cls, config: dict, tp_size: int):
config = deepcopy(config) cfg = Config.model_validate(config)
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["resources"]["requests"]["gpu"] = str(tp_size) ].resources.requests["gpu"] = str(tp_size)
if ( if (
"limits" cfg.spec.services[
in config["spec"]["services"][
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["resources"] ].resources.limits
is not None
): ):
config["spec"]["services"][ # Explicitly cast `limits` as the typecheck cannot determine that
# limits is not None here
cast(
dict[str, str],
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["resources"]["limits"]["gpu"] = str(tp_size) ].resources.limits,
)["gpu"] = str(tp_size)
args = config["spec"]["services"][ args = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] ].extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -186,18 +228,17 @@ class VllmV1ConfigModifier: ...@@ -186,18 +228,17 @@ class VllmV1ConfigModifier:
except ValueError: except ValueError:
args = append_argument(args, ["--tensor-parallel-size", str(tp_size)]) args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args) ].extraPodSpec.mainContainer.args = join_arguments(args)
return config return cfg.model_dump()
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
args = config["spec"]["services"][worker_name]["extraPodSpec"]["mainContainer"][ args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args
"args"
]
args = break_arguments(args) args = break_arguments(args)
for i, arg in enumerate(args): for i, arg in enumerate(args):
...@@ -211,9 +252,8 @@ class VllmV1ConfigModifier: ...@@ -211,9 +252,8 @@ class VllmV1ConfigModifier:
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
args = config["spec"]["services"]["Frontend"]["extraPodSpec"]["mainContainer"][ cfg = Config.model_validate(config)
"args" args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args
]
args = break_arguments(args) args = break_arguments(args)
try: try:
idx = args.index("--http-port") idx = args.index("--http-port")
...@@ -251,29 +291,29 @@ class VllmV1ConfigModifier: ...@@ -251,29 +291,29 @@ class VllmV1ConfigModifier:
class SGLangConfigModifier: class SGLangConfigModifier:
@classmethod @classmethod
def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict: def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
config = deepcopy(config) cfg = Config.model_validate(config)
# set metadata name # set metadata name
config["metadata"]["name"] = "sglang-agg" cfg.metadata.name = "sglang-agg"
# disable planner # disable planner
if "Planner" in config["spec"]["services"]: if "Planner" in cfg.spec.services:
del config["spec"]["services"]["Planner"] del cfg.spec.services["Planner"]
if target == "prefill": if target == "prefill":
# convert prefill worker into decode worker # convert prefill worker into decode worker
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
] = config["spec"]["services"][ ] = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
] ]
del config["spec"]["services"][ del cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
] ]
args = config["spec"]["services"][ args = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] ].extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -285,19 +325,19 @@ class SGLangConfigModifier: ...@@ -285,19 +325,19 @@ class SGLangConfigModifier:
if "--disable-radix-cache" not in args: if "--disable-radix-cache" not in args:
args = append_argument(args, "--disable-radix-cache") args = append_argument(args, "--disable-radix-cache")
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args) ].extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode": elif target == "decode":
# delete prefill worker # delete prefill worker
del config["spec"]["services"][ del cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
] ]
args = config["spec"]["services"][ args = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] ].extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -313,9 +353,9 @@ class SGLangConfigModifier: ...@@ -313,9 +353,9 @@ class SGLangConfigModifier:
if "--disable-radix-cache" in args: if "--disable-radix-cache" in args:
args.remove("--disable-radix-cache") args.remove("--disable-radix-cache")
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args) ].extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1 # set num workers to 1
decode_worker_config = config["spec"]["services"][ decode_worker_config = config["spec"]["services"][
...@@ -327,24 +367,29 @@ class SGLangConfigModifier: ...@@ -327,24 +367,29 @@ class SGLangConfigModifier:
@classmethod @classmethod
def set_config_tp_size(cls, config: dict, tp_size: int): def set_config_tp_size(cls, config: dict, tp_size: int):
config = deepcopy(config) cfg = Config.model_validate(config)
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["resources"]["requests"]["gpu"] = str(tp_size) ].resources.requests["gpu"] = str(tp_size)
if ( if (
"limits" cfg.spec.services[
in config["spec"]["services"][
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["resources"] ].resources.limits
is not None
): ):
config["spec"]["services"][ # Explicitly cast `limits` as the typecheck cannot determine that
# limits is not None here
cast(
dict[str, str],
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["resources"]["limits"]["gpu"] = str(tp_size) ].resources.limits,
)["gpu"] = str(tp_size)
args = config["spec"]["services"][ args = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] ].extraPodSpec.mainContainer.args
args = break_arguments(args) args = break_arguments(args)
...@@ -354,18 +399,17 @@ class SGLangConfigModifier: ...@@ -354,18 +399,17 @@ class SGLangConfigModifier:
except ValueError: except ValueError:
args = append_argument(args, ["--tp", str(tp_size)]) args = append_argument(args, ["--tp", str(tp_size)])
config["spec"]["services"][ cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args) ].extraPodSpec.mainContainer.args = join_arguments(args)
return config return cfg.model_dump()
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
args = config["spec"]["services"][worker_name]["extraPodSpec"]["mainContainer"][ args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args
"args"
]
args = break_arguments(args) args = break_arguments(args)
for i, arg in enumerate(args): for i, arg in enumerate(args):
...@@ -379,9 +423,8 @@ class SGLangConfigModifier: ...@@ -379,9 +423,8 @@ class SGLangConfigModifier:
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
args = config["spec"]["services"]["Frontend"]["extraPodSpec"]["mainContainer"][ cfg = Config.model_validate(config)
"args" args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args
]
args = break_arguments(args) args = break_arguments(args)
try: try:
idx = args.index("--http-port") idx = args.index("--http-port")
......
...@@ -42,6 +42,7 @@ classifiers = [ ...@@ -42,6 +42,7 @@ classifiers = [
dependencies = [ dependencies = [
"networkx", "networkx",
"pandas", "pandas",
"pydantic>=2",
"tabulate", "tabulate",
"types-tabulate", "types-tabulate",
"transformers", "transformers",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment