Unverified commit 7e4bc716, authored by hhzhang16 and committed by GitHub
Browse files

feat: remove default model name in Profiler; validate for one of served model...


feat: remove default model name in Profiler; validate for one of served model name and model path in Profiler (#5950)
Signed-off-by: Hannah Zhang <hannahz@nvidia.com>
parent 3a2885c8
......@@ -133,9 +133,7 @@ class BaseConfigModifier:
WORKER_SERVED_MODEL_NAME_ARG: str = "--served-model-name"
@classmethod
def _get_model_name_and_path_from_args(
cls, args: list[str], default_model_name: str, logger
) -> tuple[str, str]:
def _get_model_name_and_path_from_args(cls, args: list[str]) -> Tuple[str, str]:
"""
Extract model name and path from worker args.
......@@ -144,14 +142,14 @@ class BaseConfigModifier:
Args:
args: Broken argument list
default_model_name: Default to use if neither arg found
logger: Logger instance for warnings
Returns:
Tuple of (model_name, model_path)
"""
model_name = default_model_name
Raises:
ValueError: If neither --served-model-name nor model path arg is found
"""
model_name = None
# Check for --served-model-name first (API model name)
for i, arg in enumerate(args):
if arg == cls.WORKER_SERVED_MODEL_NAME_ARG and i + 1 < len(args):
......@@ -159,22 +157,26 @@ class BaseConfigModifier:
break
# Check for backend-specific path argument
model_path = model_name
model_path = None
for i, arg in enumerate(args):
if arg == cls.WORKER_MODEL_PATH_ARG and i + 1 < len(args):
model_path = args[i + 1]
break
# If model_name not found, use model_path as model_name
if model_name == default_model_name and model_path != default_model_name:
model_name = model_path
# Warn if neither argument was found
if model_name == default_model_name and model_path == default_model_name:
logger.warning(
f"Model name not found in configuration args, using default model name: {default_model_name}"
# Require at least one to be specified
if model_name is None and model_path is None:
raise ValueError(
f"Cannot determine model: neither {cls.WORKER_MODEL_PATH_ARG} nor "
f"{cls.WORKER_SERVED_MODEL_NAME_ARG} found in worker configuration. "
f"Please specify a model name/path in your config."
)
# If only one is specified, use it for both
if model_path is None:
model_path = model_name
elif model_name is None:
model_name = model_path
return model_name, model_path
@classmethod
......
......@@ -20,11 +20,7 @@ from benchmarks.profiler.utils.config import (
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
EngineType,
)
from benchmarks.profiler.utils.defaults import DYNAMO_RUN_DEFAULT_PORT, EngineType
from dynamo.planner.defaults import SubComponentType
logger = logging.getLogger(__name__)
......@@ -282,17 +278,10 @@ class SGLangConfigModifier(BaseConfigModifier):
@classmethod
def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config)
try:
worker_service = get_worker_service_from_config(cfg, backend="sglang")
args = validate_and_get_worker_args(worker_service, backend="sglang")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
worker_service = get_worker_service_from_config(cfg, backend="sglang")
args = validate_and_get_worker_args(worker_service, backend="sglang")
args = break_arguments(args)
return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger)
return cls._get_model_name_and_path_from_args(args)
@classmethod
def get_port(cls, config: dict) -> int:
......
......@@ -21,11 +21,7 @@ from benchmarks.profiler.utils.config import (
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
EngineType,
)
from benchmarks.profiler.utils.defaults import DYNAMO_RUN_DEFAULT_PORT, EngineType
from dynamo.planner.defaults import SubComponentType
logger = logging.getLogger(__name__)
......@@ -256,17 +252,10 @@ class TrtllmConfigModifier(BaseConfigModifier):
@classmethod
def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config)
try:
worker_service = get_worker_service_from_config(cfg, backend="trtllm")
args = validate_and_get_worker_args(worker_service, backend="trtllm")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
worker_service = get_worker_service_from_config(cfg, backend="trtllm")
args = validate_and_get_worker_args(worker_service, backend="trtllm")
args = break_arguments(args)
return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger)
return cls._get_model_name_and_path_from_args(args)
@classmethod
def get_port(cls, config: dict) -> int:
......
......@@ -19,11 +19,7 @@ from benchmarks.profiler.utils.config import (
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
EngineType,
)
from benchmarks.profiler.utils.defaults import DYNAMO_RUN_DEFAULT_PORT, EngineType
from dynamo.planner.defaults import SubComponentType
logger = logging.getLogger(__name__)
......@@ -282,17 +278,10 @@ class VllmV1ConfigModifier(BaseConfigModifier):
@classmethod
def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config)
try:
worker_service = get_worker_service_from_config(cfg, backend="vllm")
args = validate_and_get_worker_args(worker_service, backend="vllm")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
worker_service = get_worker_service_from_config(cfg, backend="vllm")
args = validate_and_get_worker_args(worker_service, backend="vllm")
args = break_arguments(args)
return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger)
return cls._get_model_name_and_path_from_args(args)
@classmethod
def get_port(cls, config: dict) -> int:
......
......@@ -15,7 +15,6 @@
from enum import Enum
DEFAULT_MODEL_NAME = "Qwen/Qwen3-0.6B"
DYNAMO_RUN_DEFAULT_PORT = 8000
# set a decode maximum concurrency due to limits of profiling tools
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment