Unverified Commit 7e4bc716 authored by hhzhang16's avatar hhzhang16 Committed by GitHub
Browse files

feat: remove default model name in Profiler; validate for one of served model...


feat: remove default model name in Profiler; validate for one of served model name and model path in Profiler (#5950)
Signed-off-by: default avatarHannah Zhang <hannahz@nvidia.com>
parent 3a2885c8
...@@ -133,9 +133,7 @@ class BaseConfigModifier: ...@@ -133,9 +133,7 @@ class BaseConfigModifier:
WORKER_SERVED_MODEL_NAME_ARG: str = "--served-model-name" WORKER_SERVED_MODEL_NAME_ARG: str = "--served-model-name"
@classmethod @classmethod
def _get_model_name_and_path_from_args( def _get_model_name_and_path_from_args(cls, args: list[str]) -> Tuple[str, str]:
cls, args: list[str], default_model_name: str, logger
) -> tuple[str, str]:
""" """
Extract model name and path from worker args. Extract model name and path from worker args.
...@@ -144,14 +142,14 @@ class BaseConfigModifier: ...@@ -144,14 +142,14 @@ class BaseConfigModifier:
Args: Args:
args: Broken argument list args: Broken argument list
default_model_name: Default to use if neither arg found
logger: Logger instance for warnings
Returns: Returns:
Tuple of (model_name, model_path) Tuple of (model_name, model_path)
"""
model_name = default_model_name
Raises:
ValueError: If neither --served-model-name nor model path arg is found
"""
model_name = None
# Check for --served-model-name first (API model name) # Check for --served-model-name first (API model name)
for i, arg in enumerate(args): for i, arg in enumerate(args):
if arg == cls.WORKER_SERVED_MODEL_NAME_ARG and i + 1 < len(args): if arg == cls.WORKER_SERVED_MODEL_NAME_ARG and i + 1 < len(args):
...@@ -159,22 +157,26 @@ class BaseConfigModifier: ...@@ -159,22 +157,26 @@ class BaseConfigModifier:
break break
# Check for backend-specific path argument # Check for backend-specific path argument
model_path = model_name model_path = None
for i, arg in enumerate(args): for i, arg in enumerate(args):
if arg == cls.WORKER_MODEL_PATH_ARG and i + 1 < len(args): if arg == cls.WORKER_MODEL_PATH_ARG and i + 1 < len(args):
model_path = args[i + 1] model_path = args[i + 1]
break break
# If model_name not found, use model_path as model_name # Require at least one to be specified
if model_name == default_model_name and model_path != default_model_name: if model_name is None and model_path is None:
model_name = model_path raise ValueError(
f"Cannot determine model: neither {cls.WORKER_MODEL_PATH_ARG} nor "
# Warn if neither argument was found f"{cls.WORKER_SERVED_MODEL_NAME_ARG} found in worker configuration. "
if model_name == default_model_name and model_path == default_model_name: f"Please specify a model name/path in your config."
logger.warning(
f"Model name not found in configuration args, using default model name: {default_model_name}"
) )
# If only one is specified, use it for both
if model_path is None:
model_path = model_name
elif model_name is None:
model_name = model_path
return model_name, model_path return model_name, model_path
@classmethod @classmethod
......
...@@ -20,11 +20,7 @@ from benchmarks.profiler.utils.config import ( ...@@ -20,11 +20,7 @@ from benchmarks.profiler.utils.config import (
validate_and_get_worker_args, validate_and_get_worker_args,
) )
from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier
from benchmarks.profiler.utils.defaults import ( from benchmarks.profiler.utils.defaults import DYNAMO_RUN_DEFAULT_PORT, EngineType
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
EngineType,
)
from dynamo.planner.defaults import SubComponentType from dynamo.planner.defaults import SubComponentType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -282,17 +278,10 @@ class SGLangConfigModifier(BaseConfigModifier): ...@@ -282,17 +278,10 @@ class SGLangConfigModifier(BaseConfigModifier):
@classmethod @classmethod
def get_model_name(cls, config: dict) -> Tuple[str, str]: def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
try: worker_service = get_worker_service_from_config(cfg, backend="sglang")
worker_service = get_worker_service_from_config(cfg, backend="sglang") args = validate_and_get_worker_args(worker_service, backend="sglang")
args = validate_and_get_worker_args(worker_service, backend="sglang")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
args = break_arguments(args) args = break_arguments(args)
return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger) return cls._get_model_name_and_path_from_args(args)
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
......
...@@ -21,11 +21,7 @@ from benchmarks.profiler.utils.config import ( ...@@ -21,11 +21,7 @@ from benchmarks.profiler.utils.config import (
validate_and_get_worker_args, validate_and_get_worker_args,
) )
from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier
from benchmarks.profiler.utils.defaults import ( from benchmarks.profiler.utils.defaults import DYNAMO_RUN_DEFAULT_PORT, EngineType
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
EngineType,
)
from dynamo.planner.defaults import SubComponentType from dynamo.planner.defaults import SubComponentType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -256,17 +252,10 @@ class TrtllmConfigModifier(BaseConfigModifier): ...@@ -256,17 +252,10 @@ class TrtllmConfigModifier(BaseConfigModifier):
@classmethod @classmethod
def get_model_name(cls, config: dict) -> Tuple[str, str]: def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
try: worker_service = get_worker_service_from_config(cfg, backend="trtllm")
worker_service = get_worker_service_from_config(cfg, backend="trtllm") args = validate_and_get_worker_args(worker_service, backend="trtllm")
args = validate_and_get_worker_args(worker_service, backend="trtllm")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
args = break_arguments(args) args = break_arguments(args)
return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger) return cls._get_model_name_and_path_from_args(args)
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
......
...@@ -19,11 +19,7 @@ from benchmarks.profiler.utils.config import ( ...@@ -19,11 +19,7 @@ from benchmarks.profiler.utils.config import (
validate_and_get_worker_args, validate_and_get_worker_args,
) )
from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier from benchmarks.profiler.utils.config_modifiers.protocol import BaseConfigModifier
from benchmarks.profiler.utils.defaults import ( from benchmarks.profiler.utils.defaults import DYNAMO_RUN_DEFAULT_PORT, EngineType
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
EngineType,
)
from dynamo.planner.defaults import SubComponentType from dynamo.planner.defaults import SubComponentType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -282,17 +278,10 @@ class VllmV1ConfigModifier(BaseConfigModifier): ...@@ -282,17 +278,10 @@ class VllmV1ConfigModifier(BaseConfigModifier):
@classmethod @classmethod
def get_model_name(cls, config: dict) -> Tuple[str, str]: def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
try: worker_service = get_worker_service_from_config(cfg, backend="vllm")
worker_service = get_worker_service_from_config(cfg, backend="vllm") args = validate_and_get_worker_args(worker_service, backend="vllm")
args = validate_and_get_worker_args(worker_service, backend="vllm")
except (ValueError, KeyError):
logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
args = break_arguments(args) args = break_arguments(args)
return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger) return cls._get_model_name_and_path_from_args(args)
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
from enum import Enum from enum import Enum
DEFAULT_MODEL_NAME = "Qwen/Qwen3-0.6B"
DYNAMO_RUN_DEFAULT_PORT = 8000 DYNAMO_RUN_DEFAULT_PORT = 8000
# set a decode maximum concurrency due to limits of profiling tools # set a decode maximum concurrency due to limits of profiling tools
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment