"vscode:/vscode.git/clone" did not exist on "4adab52df27552cb58108f8525dcabf3dc29b4db"
Unverified Commit 882ae1b4 authored by Michael Shin, committed by GitHub
Browse files

chore: add pydantic validation for benchmark profile config (#2791)


Signed-off-by: Michael Shin <michaelshin@users.noreply.github.com>
parent 5bbbeae3
......@@ -15,9 +15,9 @@
import logging
import re
from copy import deepcopy
from typing import Literal
from typing import Literal, Optional, cast
from pydantic import BaseModel
from utils.defaults import DEFAULT_MODEL_NAME, DYNAMO_RUN_DEFAULT_PORT
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
......@@ -33,6 +33,43 @@ console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
class Container(BaseModel):
    """Schema for the ``mainContainer`` entry of a pod spec (see ``PodSpec``).

    Only ``args`` is declared. Any other keys present in the input are kept
    on the model so that validating a config and dumping it again does not
    discard them.
    """

    # Pydantic ignores undeclared keys by default, which would silently drop
    # them from `model_dump()`; allow them so the config round-trips intact.
    model_config = {"extra": "allow"}

    # Argument list of the container; an absent key validates as an empty list.
    # NOTE(review): `break_arguments` also checks for a plain `str`; a string
    # value would be rejected by this annotation — confirm whether configs
    # ever carry `args` as a single string.
    args: list[str] = []
class PodSpec(BaseModel):
    """Schema for a service's ``extraPodSpec`` section.

    Requires a ``mainContainer`` entry; any other keys present in the input
    are kept so that they survive a validate/dump round-trip.
    """

    # Pydantic ignores undeclared keys by default, which would silently drop
    # them from `model_dump()`; allow them so the config round-trips intact.
    model_config = {"extra": "allow"}

    # The container whose `args` the config modifiers read and rewrite.
    mainContainer: Container
class ServiceResources(BaseModel):
    """Schema for a service's ``resources`` section.

    ``requests`` is required and ``limits`` is optional. Any other keys
    present in the input are kept so that they survive a validate/dump
    round-trip.
    """

    # Pydantic ignores undeclared keys by default, which would silently drop
    # them from `model_dump()`; allow them so the config round-trips intact.
    model_config = {"extra": "allow"}

    # Resource name -> quantity, both strings (e.g. the "gpu" entry written
    # by `set_config_tp_size`).
    requests: dict[str, str]
    # Same shape as `requests`; None when the config has no limits section.
    limits: Optional[dict[str, str]] = None
class Service(BaseModel):
    """Schema for a single entry of ``spec.services``.

    ``replicas``, ``resources`` and ``extraPodSpec`` are all required. Any
    other keys present in the input are kept so that they survive a
    validate/dump round-trip.
    """

    # Pydantic ignores undeclared keys by default, which would silently drop
    # them from `model_dump()`; allow them so the config round-trips intact.
    model_config = {"extra": "allow"}

    # Replica count of the service.
    replicas: int
    # Resource requests and optional limits of the service.
    resources: ServiceResources
    # Pod spec section holding the main container and its args.
    extraPodSpec: PodSpec
class Services(BaseModel):
    """Service collection that must contain a ``Frontend`` entry.

    Services under any other name are accepted as extra fields.
    """

    # Pydantic v2 refuses to create a model with a field named `__root__`
    # (root models are expressed with `pydantic.RootModel` instead), so the
    # open-ended "any other service name" part is modelled with extra fields.
    # NOTE(review): extra entries are stored as given and are not validated
    # against `Service`.
    model_config = {"extra": "allow"}

    # The one service that is always required.
    Frontend: Service
class Spec(BaseModel):
    """Schema for the ``spec`` section of a deployment config.

    Requires a ``services`` mapping; any other keys present in the input are
    kept so that they survive a validate/dump round-trip.
    """

    # Pydantic ignores undeclared keys by default, which would silently drop
    # them from `model_dump()`; allow them so the config round-trips intact.
    model_config = {"extra": "allow"}

    # Service name -> service definition.
    services: dict[str, Service]
class Metadata(BaseModel):
    """Schema for the ``metadata`` section of a deployment config.

    Requires ``name``; any other keys present in the input are kept so that
    they survive a validate/dump round-trip.
    """

    # Pydantic ignores undeclared keys by default, which would silently drop
    # them from `model_dump()`; allow them so the config round-trips intact.
    model_config = {"extra": "allow"}

    # Name of the deployment; overwritten by the `convert_config` methods.
    name: str
class Config(BaseModel):
    """Top-level schema that the config modifiers validate input dicts against.

    Requires ``metadata`` and ``spec``. Any other top-level keys present in
    the input are kept so that ``Config.model_validate(d).model_dump()`` does
    not discard them.
    """

    # Pydantic ignores undeclared keys by default, which would silently drop
    # them from `model_dump()`; allow them so the config round-trips intact.
    model_config = {"extra": "allow"}

    # Deployment metadata (at least a name).
    metadata: Metadata
    # Deployment spec holding the services mapping.
    spec: Spec
def break_arguments(args: list[str]) -> list[str]:
ans = []
if isinstance(args, str):
......@@ -88,29 +125,29 @@ def find_arg_index(args: list[str]) -> int:
class VllmV1ConfigModifier:
@classmethod
def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
config = deepcopy(config)
cfg = Config.model_validate(config)
# set metadata name
config["metadata"]["name"] = "vllm-agg"
cfg.metadata.name = "vllm-agg"
# disable planner
if "Planner" in config["spec"]["services"]:
del config["spec"]["services"]["Planner"]
if "Planner" in cfg.spec.services:
del cfg.spec.services["Planner"]
if target == "prefill":
# convert prefill worker into decode worker
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
] = config["spec"]["services"][
] = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
]
del config["spec"]["services"][
del cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
]
args = config["spec"]["services"][
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"]
].extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -123,19 +160,19 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" not in args:
args = append_argument(args, "--no-enable-prefix-caching")
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args)
].extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode":
# delete prefill worker
del config["spec"]["services"][
del cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
]
args = config["spec"]["services"][
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"]
].extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -145,38 +182,43 @@ class VllmV1ConfigModifier:
if "--no-enable-prefix-caching" in args:
args.remove("--no-enable-prefix-caching")
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args)
].extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1
decode_worker_config = config["spec"]["services"][
decode_worker_config = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]
decode_worker_config["replicas"] = 1
decode_worker_config.replicas = 1
return config
return cfg.model_dump()
@classmethod
def set_config_tp_size(cls, config: dict, tp_size: int):
config = deepcopy(config)
cfg = Config.model_validate(config)
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["resources"]["requests"]["gpu"] = str(tp_size)
].resources.requests["gpu"] = str(tp_size)
if (
"limits"
in config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["resources"]
].resources.limits
is not None
):
config["spec"]["services"][
# Explicitly cast `limits` as the typecheck cannot determine that
# limits is not None here
cast(
dict[str, str],
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["resources"]["limits"]["gpu"] = str(tp_size)
].resources.limits,
)["gpu"] = str(tp_size)
args = config["spec"]["services"][
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"]
].extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -186,18 +228,17 @@ class VllmV1ConfigModifier:
except ValueError:
args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args)
].extraPodSpec.mainContainer.args = join_arguments(args)
return config
return cfg.model_dump()
@classmethod
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
args = config["spec"]["services"][worker_name]["extraPodSpec"]["mainContainer"][
"args"
]
args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args
args = break_arguments(args)
for i, arg in enumerate(args):
......@@ -211,9 +252,8 @@ class VllmV1ConfigModifier:
@classmethod
def get_port(cls, config: dict) -> int:
args = config["spec"]["services"]["Frontend"]["extraPodSpec"]["mainContainer"][
"args"
]
cfg = Config.model_validate(config)
args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args
args = break_arguments(args)
try:
idx = args.index("--http-port")
......@@ -251,29 +291,29 @@ class VllmV1ConfigModifier:
class SGLangConfigModifier:
@classmethod
def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
config = deepcopy(config)
cfg = Config.model_validate(config)
# set metadata name
config["metadata"]["name"] = "sglang-agg"
cfg.metadata.name = "sglang-agg"
# disable planner
if "Planner" in config["spec"]["services"]:
del config["spec"]["services"]["Planner"]
if "Planner" in cfg.spec.services:
del cfg.spec.services["Planner"]
if target == "prefill":
# convert prefill worker into decode worker
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
] = config["spec"]["services"][
] = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
]
del config["spec"]["services"][
del cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
]
args = config["spec"]["services"][
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"]
].extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -285,19 +325,19 @@ class SGLangConfigModifier:
if "--disable-radix-cache" not in args:
args = append_argument(args, "--disable-radix-cache")
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args)
].extraPodSpec.mainContainer.args = join_arguments(args)
elif target == "decode":
# delete prefill worker
del config["spec"]["services"][
del cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
]
args = config["spec"]["services"][
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"]
].extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -313,9 +353,9 @@ class SGLangConfigModifier:
if "--disable-radix-cache" in args:
args.remove("--disable-radix-cache")
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args)
].extraPodSpec.mainContainer.args = join_arguments(args)
# set num workers to 1
decode_worker_config = config["spec"]["services"][
......@@ -327,24 +367,29 @@ class SGLangConfigModifier:
@classmethod
def set_config_tp_size(cls, config: dict, tp_size: int):
config = deepcopy(config)
cfg = Config.model_validate(config)
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["resources"]["requests"]["gpu"] = str(tp_size)
].resources.requests["gpu"] = str(tp_size)
if (
"limits"
in config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["resources"]
].resources.limits
is not None
):
config["spec"]["services"][
# Explicitly cast `limits` as the typecheck cannot determine that
# limits is not None here
cast(
dict[str, str],
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["resources"]["limits"]["gpu"] = str(tp_size)
].resources.limits,
)["gpu"] = str(tp_size)
args = config["spec"]["services"][
args = cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"]
].extraPodSpec.mainContainer.args
args = break_arguments(args)
......@@ -354,18 +399,17 @@ class SGLangConfigModifier:
except ValueError:
args = append_argument(args, ["--tp", str(tp_size)])
config["spec"]["services"][
cfg.spec.services[
WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
]["extraPodSpec"]["mainContainer"]["args"] = join_arguments(args)
].extraPodSpec.mainContainer.args = join_arguments(args)
return config
return cfg.model_dump()
@classmethod
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
args = config["spec"]["services"][worker_name]["extraPodSpec"]["mainContainer"][
"args"
]
args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args
args = break_arguments(args)
for i, arg in enumerate(args):
......@@ -379,9 +423,8 @@ class SGLangConfigModifier:
@classmethod
def get_port(cls, config: dict) -> int:
args = config["spec"]["services"]["Frontend"]["extraPodSpec"]["mainContainer"][
"args"
]
cfg = Config.model_validate(config)
args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args
args = break_arguments(args)
try:
idx = args.index("--http-port")
......
......@@ -42,6 +42,7 @@ classifiers = [
dependencies = [
"networkx",
"pandas",
"pydantic>=2",
"tabulate",
"types-tabulate",
"transformers",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment