Unverified commit 2b199540, authored by hhzhang16 and committed by GitHub
Browse files

fix: check --served-model-name first before --model/--model-path (#5881)


Signed-off-by: Hannah Zhang <hannahz@nvidia.com>
parent 6e568d45
...@@ -160,7 +160,7 @@ async def run_profile(args): ...@@ -160,7 +160,7 @@ async def run_profile(args):
logger.info(f"Profiling GPU counts: {profile_num_gpus}") logger.info(f"Profiling GPU counts: {profile_num_gpus}")
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
model_name = config_modifier.get_model_name(config) model_name, model_path = config_modifier.get_model_name(config)
# Determine sweep max context length: allow user-provided cap to override model's if smaller # Determine sweep max context length: allow user-provided cap to override model's if smaller
use_specified_max_context_len = getattr(args, "max_context_length", None) use_specified_max_context_len = getattr(args, "max_context_length", None)
...@@ -302,7 +302,7 @@ async def run_profile(args): ...@@ -302,7 +302,7 @@ async def run_profile(args):
args.isl, args.isl,
ai_perf_artifact_dir, ai_perf_artifact_dir,
model_name, model_name,
model_name, model_path,
base_url, base_url,
attention_dp_size=mapping.get_attn_dp_size(), attention_dp_size=mapping.get_attn_dp_size(),
) )
...@@ -449,7 +449,7 @@ async def run_profile(args): ...@@ -449,7 +449,7 @@ async def run_profile(args):
num_request, num_request,
ai_perf_artifact_dir, ai_perf_artifact_dir,
model_name, model_name,
model_name, model_path,
base_url=base_url, base_url=base_url,
num_gpus=num_gpus, num_gpus=num_gpus,
attention_dp_size=mapping.get_attn_dp_size(), attention_dp_size=mapping.get_attn_dp_size(),
...@@ -631,7 +631,7 @@ async def run_profile(args): ...@@ -631,7 +631,7 @@ async def run_profile(args):
profile_prefill( profile_prefill(
work_dir, work_dir,
model_name, model_name,
model_name, model_path,
base_url, base_url,
best_prefill_gpus, best_prefill_gpus,
sweep_max_context_length, sweep_max_context_length,
...@@ -718,7 +718,7 @@ async def run_profile(args): ...@@ -718,7 +718,7 @@ async def run_profile(args):
profile_decode( profile_decode(
work_dir, work_dir,
model_name, model_name,
model_name, model_path,
base_url, base_url,
best_decode_gpus, best_decode_gpus,
max_kv_tokens, max_kv_tokens,
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Protocol from typing import Any, Protocol, Tuple
from benchmarks.profiler.utils.config import ( from benchmarks.profiler.utils.config import (
Config, Config,
...@@ -69,7 +69,7 @@ class ConfigModifierProtocol(Protocol): ...@@ -69,7 +69,7 @@ class ConfigModifierProtocol(Protocol):
... ...
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> Tuple[str, str]:
... ...
@classmethod @classmethod
...@@ -132,6 +132,51 @@ class BaseConfigModifier: ...@@ -132,6 +132,51 @@ class BaseConfigModifier:
WORKER_MODEL_PATH_ARG: str = "--model-path" WORKER_MODEL_PATH_ARG: str = "--model-path"
WORKER_SERVED_MODEL_NAME_ARG: str = "--served-model-name" WORKER_SERVED_MODEL_NAME_ARG: str = "--served-model-name"
@classmethod
def _get_model_name_and_path_from_args(
    cls, args: list[str], default_model_name: str, logger
) -> tuple[str, str]:
    """
    Extract the model name and model path from worker args.

    The served (API) model name is read from the flag named by
    ``cls.WORKER_SERVED_MODEL_NAME_ARG``; the model path is read from the
    flag named by ``cls.WORKER_MODEL_PATH_ARG``. Whichever one is missing
    falls back to the other, and if both are missing the default is used
    for both and a warning is logged.

    Args:
        args: Argument list already split into individual tokens
        default_model_name: Value returned for both fields if neither flag
            is present
        logger: Logger instance used for the "not found" warning

    Returns:
        Tuple of (model_name, model_path)
    """

    def _value_after(flag):
        # First token that directly follows an exact occurrence of `flag`;
        # a trailing flag with no value after it is ignored. Returns None
        # when the flag is absent.
        return next(
            (value for token, value in zip(args, args[1:]) if token == flag),
            None,
        )

    served_name = _value_after(cls.WORKER_SERVED_MODEL_NAME_ARG)
    path_value = _value_after(cls.WORKER_MODEL_PATH_ARG)

    # Absence is tracked with None rather than by comparing against
    # default_model_name, so a flag whose value happens to equal the default
    # is still honoured and does not trigger the warning below.
    if served_name is None and path_value is None:
        logger.warning(
            f"Model name not found in configuration args, using default model name: {default_model_name}"
        )
        return default_model_name, default_model_name

    # Served name wins for the API name; otherwise reuse the path.
    model_name = served_name if served_name is not None else path_value
    # Path flag wins for the path; otherwise reuse the served name.
    model_path = path_value if path_value is not None else model_name
    return model_name, model_path
@classmethod @classmethod
def _normalize_model_path(cls, pvc_mount_path: str, pvc_path: str) -> str: def _normalize_model_path(cls, pvc_mount_path: str, pvc_path: str) -> str:
mount = (pvc_mount_path or "").rstrip("/") mount = (pvc_mount_path or "").rstrip("/")
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import logging import logging
import re import re
from typing import Tuple
import yaml import yaml
...@@ -279,7 +280,7 @@ class SGLangConfigModifier(BaseConfigModifier): ...@@ -279,7 +280,7 @@ class SGLangConfigModifier(BaseConfigModifier):
return cfg.model_dump() return cfg.model_dump()
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
try: try:
worker_service = get_worker_service_from_config(cfg, backend="sglang") worker_service = get_worker_service_from_config(cfg, backend="sglang")
...@@ -288,23 +289,10 @@ class SGLangConfigModifier(BaseConfigModifier): ...@@ -288,23 +289,10 @@ class SGLangConfigModifier(BaseConfigModifier):
logger.warning( logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}" f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
) )
return DEFAULT_MODEL_NAME return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
args = break_arguments(args) args = break_arguments(args)
# Check for --model-path first (primary argument for SGLang) return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger)
for i, arg in enumerate(args):
if arg == "--model-path" and i + 1 < len(args):
return args[i + 1]
# Fall back to --served-model-name if --model-path not found
for i, arg in enumerate(args):
if arg == "--served-model-name" and i + 1 < len(args):
return args[i + 1]
logger.warning(
f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import json import json
import logging import logging
import re import re
from typing import Tuple
import yaml import yaml
...@@ -253,7 +254,7 @@ class TrtllmConfigModifier(BaseConfigModifier): ...@@ -253,7 +254,7 @@ class TrtllmConfigModifier(BaseConfigModifier):
) )
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
try: try:
worker_service = get_worker_service_from_config(cfg, backend="trtllm") worker_service = get_worker_service_from_config(cfg, backend="trtllm")
...@@ -262,17 +263,10 @@ class TrtllmConfigModifier(BaseConfigModifier): ...@@ -262,17 +263,10 @@ class TrtllmConfigModifier(BaseConfigModifier):
logger.warning( logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}" f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
) )
return DEFAULT_MODEL_NAME return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
args = break_arguments(args) args = break_arguments(args)
for i, arg in enumerate(args): return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger)
if arg == "--served-model-name" and i + 1 < len(args):
return args[i + 1]
logger.warning(
f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import Tuple
import yaml import yaml
...@@ -279,7 +280,7 @@ class VllmV1ConfigModifier(BaseConfigModifier): ...@@ -279,7 +280,7 @@ class VllmV1ConfigModifier(BaseConfigModifier):
return cfg.model_dump() return cfg.model_dump()
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> Tuple[str, str]:
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
try: try:
worker_service = get_worker_service_from_config(cfg, backend="vllm") worker_service = get_worker_service_from_config(cfg, backend="vllm")
...@@ -288,17 +289,10 @@ class VllmV1ConfigModifier(BaseConfigModifier): ...@@ -288,17 +289,10 @@ class VllmV1ConfigModifier(BaseConfigModifier):
logger.warning( logger.warning(
f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}" f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
) )
return DEFAULT_MODEL_NAME return DEFAULT_MODEL_NAME, DEFAULT_MODEL_NAME
args = break_arguments(args) args = break_arguments(args)
for i, arg in enumerate(args): return cls._get_model_name_and_path_from_args(args, DEFAULT_MODEL_NAME, logger)
if arg == "--model" and i + 1 < len(args):
return args[i + 1]
logger.warning(
f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
)
return DEFAULT_MODEL_NAME
@classmethod @classmethod
def get_port(cls, config: dict) -> int: def get_port(cls, config: dict) -> int:
......
...@@ -80,8 +80,8 @@ def auto_generate_search_space(args: argparse.Namespace) -> None: ...@@ -80,8 +80,8 @@ def auto_generate_search_space(args: argparse.Namespace) -> None:
model_name_or_path = args.model model_name_or_path = args.model
else: else:
# get the model name from config # get the model name from config
args.model = config_modifier.get_model_name(config) args.model, args.model_path = config_modifier.get_model_name(config)
model_name_or_path = args.model model_name_or_path = args.model_path
logger.info(f"Getting model info for {args.model} at {model_name_or_path}...") logger.info(f"Getting model info for {args.model} at {model_name_or_path}...")
try: try:
model_info = get_model_info(model_name_or_path) model_info = get_model_info(model_name_or_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment