Unverified Commit f26d8da7 authored by Hongkuan Zhou, committed by GitHub
Browse files

chore: separate config modifier into multiple files (#3627)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent caaea7ad
...@@ -23,11 +23,8 @@ import numpy as np ...@@ -23,11 +23,8 @@ import numpy as np
import yaml import yaml
from benchmarks.profiler.utils.aiperf import benchmark_decode, benchmark_prefill from benchmarks.profiler.utils.aiperf import benchmark_decode, benchmark_prefill
from benchmarks.profiler.utils.config import ( from benchmarks.profiler.utils.config import generate_dgd_config_with_planner
CONFIG_MODIFIERS, from benchmarks.profiler.utils.config_modifiers import CONFIG_MODIFIERS
WORKER_COMPONENT_NAMES,
generate_dgd_config_with_planner,
)
from benchmarks.profiler.utils.estimate_perf import AIConfiguratorPerfEstimator from benchmarks.profiler.utils.estimate_perf import AIConfiguratorPerfEstimator
from benchmarks.profiler.utils.planner_utils import add_planner_arguments_to_parser from benchmarks.profiler.utils.planner_utils import add_planner_arguments_to_parser
from benchmarks.profiler.utils.plot import ( from benchmarks.profiler.utils.plot import (
...@@ -53,6 +50,7 @@ from deploy.utils.dynamo_deployment import ( ...@@ -53,6 +50,7 @@ from deploy.utils.dynamo_deployment import (
DynamoDeploymentClient, DynamoDeploymentClient,
cleanup_remaining_deployments, cleanup_remaining_deployments,
) )
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
......
...@@ -16,17 +16,12 @@ ...@@ -16,17 +16,12 @@
import json import json
import logging import logging
import math import math
import re
import shlex import shlex
from typing import Literal, Optional, Protocol from typing import Literal, Optional, Protocol
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from benchmarks.profiler.utils.planner_utils import build_planner_args_from_namespace from benchmarks.profiler.utils.planner_utils import build_planner_args_from_namespace
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SubComponentType from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SubComponentType
...@@ -407,782 +402,6 @@ class ConfigModifierProtocol(Protocol): ...@@ -407,782 +402,6 @@ class ConfigModifierProtocol(Protocol):
... ...
class VllmV1ConfigModifier:
    """Config modifier for the ``vllm`` backend.

    Every method is a classmethod. Methods that take ``config`` validate it
    with ``Config.model_validate`` first; the methods that modify the config
    return the result of ``cfg.model_dump()``, so callers get a new plain dict.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Collapse a prefill + decode config into a single decode-typed worker.

        Args:
            config: Deployment config as a plain dict.
            target: ``"prefill"`` keeps the prefill worker (re-registered under
                the decode service name) with prefix caching disabled;
                ``"decode"`` keeps the decode worker with prefix caching
                enabled.
            is_moe_model: Must be false; MoE is not supported for this backend.

        Returns:
            The converted config, named ``"vllm-agg"``, without the ``Planner``
            service and with the remaining worker's ``replicas`` set to 1.

        Raises:
            NotImplementedError: If ``is_moe_model`` is true.
        """
        if is_moe_model:
            raise NotImplementedError(
                "MoE model support is not implemented for VLLM backend"
            )

        cfg = Config.model_validate(config)
        cfg.metadata.name = "vllm-agg"

        # The planner service is dropped from the converted deployment.
        if "Planner" in cfg.spec.services:
            del cfg.spec.services["Planner"]

        if target in ("prefill", "decode"):
            # Get service names by inferring from subComponentType first.
            prefill_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.PREFILL
            )
            decode_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.DECODE
            )

            if target == "prefill":
                # Re-register the prefill worker under the decode service name.
                cfg.spec.services[decode_service_name] = cfg.spec.services[
                    prefill_service_name
                ]
            del cfg.spec.services[prefill_service_name]

            # The surviving worker is always typed as "decode".
            cfg.spec.services[decode_service_name].subComponentType = "decode"

            worker_service = get_worker_service_from_config(
                cfg,
                backend="vllm",
                sub_component_type=SubComponentType.DECODE,
            )
            args = validate_and_get_worker_args(worker_service, backend="vllm")
            args = break_arguments(args)

            if target == "prefill":
                # Guarded like the other flag edits below: a bare
                # list.remove() raises ValueError when the flag is absent.
                if "--is-prefill-worker" in args:
                    args.remove("--is-prefill-worker")
                # Disable prefix caching.
                if "--enable-prefix-caching" in args:
                    args.remove("--enable-prefix-caching")
                if "--no-enable-prefix-caching" not in args:
                    args = append_argument(args, "--no-enable-prefix-caching")
            else:
                # Enable prefix caching.
                if "--enable-prefix-caching" not in args:
                    args = append_argument(args, "--enable-prefix-caching")
                if "--no-enable-prefix-caching" in args:
                    args.remove("--no-enable-prefix-caching")

            worker_service.extraPodSpec.mainContainer.args = args

        # Set the number of workers to 1, using the inferred decode service name.
        final_decode_service_name = get_service_name_by_type(
            cfg, "vllm", SubComponentType.DECODE
        )
        cfg.spec.services[final_decode_service_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Set ``--tensor-parallel-size`` on the selected worker.

        Args:
            config: Deployment config as a plain dict.
            tp_size: Value written after ``--tensor-parallel-size`` and passed
                to ``setup_worker_service_resources``.
            component_type: Which worker to modify.

        Returns:
            The updated config as a plain dict.
        """
        cfg = Config.model_validate(config)
        worker_service = get_worker_service_from_config(
            cfg, backend="vllm", sub_component_type=component_type
        )

        setup_worker_service_resources(worker_service, tp_size)

        args = validate_and_get_worker_args(worker_service, backend="vllm")
        args = break_arguments(args)

        try:
            idx = args.index("--tensor-parallel-size")
        except ValueError:
            # Flag absent: append both the flag and its value.
            args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])
        else:
            if idx + 1 < len(args):
                args[idx + 1] = str(tp_size)
            else:
                # Flag is the last token with no value; supply the value
                # instead of raising IndexError.
                args = append_argument(args, str(tp_size))

        worker_service.extraPodSpec.mainContainer.args = args
        return cfg.model_dump()

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Not supported for this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "TEP (Tensor Expert Parallelism) is not implemented for VLLM backend"
        )

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Not supported for this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "DEP (Data Expert Parallelism) is not implemented for VLLM backend"
        )

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--model`` in the worker args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        lookup or args validation raises ``ValueError``/``KeyError``, or when
        ``--model`` is missing or has no value after it.
        """
        cfg = Config.model_validate(config)
        try:
            worker_service = get_worker_service_from_config(cfg, backend="vllm")
            args = validate_and_get_worker_args(worker_service, backend="vllm")
        except (ValueError, KeyError):
            logger.warning(
                f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        args = break_arguments(args)
        for i, arg in enumerate(args):
            if arg == "--model" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the integer following ``--http-port`` in the Frontend args.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the
        ``Frontend`` service, its container, its args, or the flag is missing,
        or when the value is not a valid integer.
        """
        cfg = Config.model_validate(config)
        frontend_service = cfg.spec.services.get("Frontend")
        if (
            not frontend_service
            or not frontend_service.extraPodSpec
            or not frontend_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = frontend_service.extraPodSpec.mainContainer.args
        if not args:
            logger.warning(
                f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = break_arguments(args)
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            # ValueError covers both a missing flag and a non-integer value.
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Parse the KV cache size (in tokens) from a log file.

        Uses the first line containing ``"Maximum concurrency for"`` and
        returns ``int(token_count * concurrency)`` parsed from it.

        Args:
            dynamo_log_fn: Path of the log file to scan.
            attention_dp_size: Not used by this backend; the other backend
                modifiers take the same parameter.

        Returns:
            The computed size, or 0 if no matching line is found or if
            opening or parsing the file fails (a warning is logged on failure).
        """
        # Pre-bind so the except handler can format its message even when
        # open() fails before the loop assigns a line.
        line = None
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    if "Maximum concurrency for" not in line:
                        continue
                    remainder = line.strip().split("Maximum concurrency for ")[1]
                    parts = remainder.split(" tokens per request: ")
                    token_count = int(parts[0].replace(",", ""))
                    # Drops the final character of the concurrency figure —
                    # presumably a trailing "x" suffix; TODO confirm.
                    concurrency = float(parts[1][:-1])
                    logger.info(
                        f"Found KV cache info: {token_count} x {concurrency} = {int(token_count * concurrency)}"
                    )
                    return int(token_count * concurrency)
        except Exception as e:
            logger.warning(
                f"Failed to parse KV cache size from line: {line}. Error: {e}"
            )
        return 0
class SGLangConfigModifier:
    """Config modifier for the ``sglang`` backend.

    Every method is a classmethod. Methods that take ``config`` validate it
    with ``Config.model_validate`` first; the methods that modify the config
    return the result of ``cfg.model_dump()``, so callers get a new plain dict.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Collapse a prefill + decode config into a single decode-typed worker.

        Args:
            config: Deployment config as a plain dict.
            target: ``"prefill"`` keeps the prefill worker (re-registered under
                the decode service name) and adds ``--disable-radix-cache``;
                ``"decode"`` keeps the decode worker and removes that flag.
            is_moe_model: When true and ``target`` is ``"decode"``,
                ``--load-balance-method`` is forced to ``round_robin``.

        Returns:
            The converted config, named ``"sglang-agg"``, without the
            ``Planner`` service and with the remaining worker's ``replicas``
            set to 1.
        """
        cfg = Config.model_validate(config)
        cfg.metadata.name = "sglang-agg"

        # The planner service is dropped from the converted deployment.
        if "Planner" in cfg.spec.services:
            del cfg.spec.services["Planner"]

        if target in ("prefill", "decode"):
            # Get service names by inferring from subComponentType first.
            prefill_service_name = get_service_name_by_type(
                cfg, "sglang", SubComponentType.PREFILL
            )
            decode_service_name = get_service_name_by_type(
                cfg, "sglang", SubComponentType.DECODE
            )

            if target == "prefill":
                # Re-register the prefill worker under the decode service name.
                cfg.spec.services[decode_service_name] = cfg.spec.services[
                    prefill_service_name
                ]
            del cfg.spec.services[prefill_service_name]

            # The surviving worker is always typed as "decode".
            cfg.spec.services[decode_service_name].subComponentType = "decode"

            worker_service = get_worker_service_from_config(
                cfg,
                backend="sglang",
                sub_component_type=SubComponentType.DECODE,
            )
            args = validate_and_get_worker_args(worker_service, backend="sglang")
            args = break_arguments(args)

            # Remove the disaggregation flags together with their values.
            args = remove_valued_arguments(args, "--disaggregation-mode")
            args = remove_valued_arguments(args, "--disaggregation-transfer-backend")
            args = remove_valued_arguments(args, "--disaggregation-bootstrap-port")

            if target == "prefill":
                # Disable prefix caching.
                if "--disable-radix-cache" not in args:
                    args = append_argument(args, "--disable-radix-cache")
            else:
                # Enable prefix caching.
                if "--disable-radix-cache" in args:
                    args.remove("--disable-radix-cache")
                if is_moe_model:
                    # need to use round_robin dp attention routing for MoE
                    # models to ensure kv reuse can skip prefill
                    if "--load-balance-method" in args:
                        idx = args.index("--load-balance-method")
                        if idx + 1 < len(args):
                            args[idx + 1] = "round_robin"
                        else:
                            # Flag is the last token with no value; supply
                            # the value instead of raising IndexError.
                            args = append_argument(args, "round_robin")
                    else:
                        args = append_argument(
                            args, ["--load-balance-method", "round_robin"]
                        )

            worker_service.extraPodSpec.mainContainer.args = args

        # Set the number of workers to 1, using the inferred decode service name.
        final_decode_service_name = get_service_name_by_type(
            cfg, "sglang", SubComponentType.DECODE
        )
        cfg.spec.services[final_decode_service_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Set ``--tp`` on the selected worker.

        Args:
            config: Deployment config as a plain dict.
            tp_size: Value written for ``--tp`` and passed to
                ``setup_worker_service_resources``.
            component_type: Which worker to modify.

        Returns:
            The updated config as a plain dict.
        """
        cfg = Config.model_validate(config)
        worker_service = get_worker_service_from_config(
            cfg, backend="sglang", sub_component_type=component_type
        )

        setup_worker_service_resources(worker_service, tp_size)

        # NOTE(review): unlike convert_config, the args are not passed through
        # break_arguments here — confirm set_argument_value copes with joined
        # argument strings.
        args = validate_and_get_worker_args(worker_service, backend="sglang")
        args = set_argument_value(args, "--tp", str(tp_size))

        worker_service.extraPodSpec.mainContainer.args = args
        return cfg.model_dump()

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Configure the selected worker for TEP with ``tep_size``.

        Sets ``--tp`` and ``--ep-size`` to ``tep_size`` and removes ``--dp``
        (with its value) and ``--enable-dp-attention``.

        Args:
            config: Deployment config as a plain dict.
            tep_size: Value written for ``--tp`` and ``--ep-size``.
            num_gpus_per_node: Forwarded to ``setup_worker_service_resources``.
            component_type: Which worker to modify.

        Returns:
            The updated config as a plain dict.
        """
        cfg = Config.model_validate(config)
        worker_service = get_worker_service_from_config(
            cfg, backend="sglang", sub_component_type=component_type
        )

        # Set up resources with multinode configuration.
        setup_worker_service_resources(worker_service, tep_size, num_gpus_per_node)

        # NOTE(review): args are not passed through break_arguments here —
        # confirm the helpers below cope with joined argument strings.
        args = validate_and_get_worker_args(worker_service, backend="sglang")

        args = set_argument_value(args, "--tp", str(tep_size))
        args = set_argument_value(args, "--ep-size", str(tep_size))
        args = remove_valued_arguments(args, "--dp")
        if "--enable-dp-attention" in args:
            args.remove("--enable-dp-attention")

        worker_service.extraPodSpec.mainContainer.args = args
        return cfg.model_dump()

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Configure the selected worker for DEP with ``dep_size``.

        Sets ``--tp``, ``--dp`` and ``--ep-size`` to ``dep_size`` and ensures
        ``--enable-dp-attention`` is present.

        Args:
            config: Deployment config as a plain dict.
            dep_size: Value written for ``--tp``, ``--dp`` and ``--ep-size``.
            num_gpus_per_node: Forwarded to ``setup_worker_service_resources``.
            component_type: Which worker to modify.

        Returns:
            The updated config as a plain dict.
        """
        cfg = Config.model_validate(config)
        worker_service = get_worker_service_from_config(
            cfg, backend="sglang", sub_component_type=component_type
        )

        # Set up resources with multinode configuration.
        setup_worker_service_resources(worker_service, dep_size, num_gpus_per_node)

        # NOTE(review): args are not passed through break_arguments here —
        # confirm the helpers below cope with joined argument strings.
        args = validate_and_get_worker_args(worker_service, backend="sglang")

        args = set_argument_value(args, "--tp", str(dep_size))
        args = set_argument_value(args, "--dp", str(dep_size))
        if "--enable-dp-attention" not in args:
            args = append_argument(args, "--enable-dp-attention")
        args = set_argument_value(args, "--ep-size", str(dep_size))

        worker_service.extraPodSpec.mainContainer.args = args
        return cfg.model_dump()

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--served-model-name`` in the worker args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        lookup or args validation raises ``ValueError``/``KeyError``, or when
        the flag is missing or has no value after it.
        """
        cfg = Config.model_validate(config)
        try:
            worker_service = get_worker_service_from_config(cfg, backend="sglang")
            args = validate_and_get_worker_args(worker_service, backend="sglang")
        except (ValueError, KeyError):
            logger.warning(
                f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        args = break_arguments(args)
        for i, arg in enumerate(args):
            if arg == "--served-model-name" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the integer following ``--http-port`` in the Frontend args.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the
        ``Frontend`` service, its container, its args, or the flag is missing,
        or when the value is not a valid integer.
        """
        cfg = Config.model_validate(config)
        frontend_service = cfg.spec.services.get("Frontend")
        if (
            not frontend_service
            or not frontend_service.extraPodSpec
            or not frontend_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = frontend_service.extraPodSpec.mainContainer.args
        if not args:
            logger.warning(
                f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = break_arguments(args)
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            # ValueError covers both a missing flag and a non-integer value.
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Parse the KV cache size (in tokens) from a log file.

        Uses the first line that contains both ``"KV Cache is allocated"`` and
        ``"#tokens:"`` and whose ``#tokens:`` value is an integer.

        Args:
            dynamo_log_fn: Path of the log file to scan.
            attention_dp_size: Multiplier applied to the parsed token count.

        Returns:
            ``tokens * attention_dp_size``, or 0 if no matching line is found
            or if reading the file fails (a warning is logged on failure).
        """
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    if "KV Cache is allocated" in line and "#tokens:" in line:
                        # Extract the number after "#tokens:".
                        match = re.search(r"#tokens:\s*(\d+)", line)
                        if match:
                            return int(match.group(1)) * attention_dp_size
        except Exception as e:
            logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")
        return 0
class TrtllmConfigModifier:
    """Config modifier for the ``trtllm`` backend.

    Every method is a classmethod. Engine settings are changed by merging keys
    into the JSON payload of ``--override-engine-args`` instead of editing
    individual command-line flags.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Collapse a prefill + decode config into a single decode-typed worker.

        For ``"prefill"`` the prefill worker is re-registered under the decode
        service name, block reuse is disabled and the overlap scheduler is
        enabled. For ``"decode"`` the decode worker is kept and block reuse is
        enabled. In both cases ``cache_transceiver_config`` is overridden to
        ``None``, the ``Planner`` service is dropped, the config is named
        ``"trtllm-agg"`` and the worker's ``replicas`` is set to 1.

        Raises:
            NotImplementedError: If ``is_moe_model`` is true.
        """
        if is_moe_model:
            raise NotImplementedError(
                "MoE model support is not implemented for TrtLLM backend"
            )

        cfg = Config.model_validate(config)
        cfg.metadata.name = "trtllm-agg"

        services = cfg.spec.services
        if "Planner" in services:
            del services["Planner"]

        if target in ("prefill", "decode"):
            keep_prefill = target == "prefill"

            # Service names are inferred from subComponentType first.
            prefill_name = get_service_name_by_type(
                cfg, "trtllm", SubComponentType.PREFILL
            )
            decode_name = get_service_name_by_type(
                cfg, "trtllm", SubComponentType.DECODE
            )

            if keep_prefill:
                # The prefill worker takes over the decode worker's name.
                services[decode_name] = services[prefill_name]
            del services[prefill_name]
            services[decode_name].subComponentType = "decode"

            worker = get_worker_service_from_config(
                cfg,
                backend="trtllm",
                sub_component_type=SubComponentType.DECODE,
            )
            cli_args = break_arguments(
                validate_and_get_worker_args(worker, backend="trtllm")
            )
            for flag in ("--disaggregation-mode", "--disaggregation-strategy"):
                cli_args = remove_valued_arguments(cli_args, flag)

            # Any override-engine-args already present are parsed so that the
            # keys set below are merged into them.
            overrides, cli_args = parse_override_engine_args(cli_args)
            if "kv_cache_config" not in overrides:
                overrides["kv_cache_config"] = {}
            # Prefill-only runs without block reuse; decode-only runs with it.
            overrides["kv_cache_config"]["enable_block_reuse"] = not keep_prefill
            if keep_prefill:
                overrides["disable_overlap_scheduler"] = False
            overrides["cache_transceiver_config"] = None

            cli_args = append_argument(
                cli_args, ["--override-engine-args", json.dumps(overrides)]
            )
            worker.extraPodSpec.mainContainer.args = cli_args

        # Exactly one replica of the remaining worker.
        remaining_name = get_service_name_by_type(
            cfg, "trtllm", SubComponentType.DECODE
        )
        services[remaining_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Set ``tensor_parallel_size`` in the worker's engine-arg overrides.

        ``tp_size`` is also passed to ``setup_worker_service_resources``.
        Returns the updated config as a plain dict.
        """
        cfg = Config.model_validate(config)
        worker = get_worker_service_from_config(
            cfg, backend="trtllm", sub_component_type=component_type
        )
        setup_worker_service_resources(worker, tp_size)

        # Arguments may arrive as joined strings or as a list; normalise them.
        cli_args = break_arguments(
            validate_and_get_worker_args(worker, backend="trtllm")
        )

        overrides, cli_args = parse_override_engine_args(cli_args)
        overrides["tensor_parallel_size"] = tp_size
        cli_args = append_argument(
            cli_args, ["--override-engine-args", json.dumps(overrides)]
        )

        worker.extraPodSpec.mainContainer.args = cli_args
        return cfg.model_dump()

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Always raise ``NotImplementedError``; TEP is unsupported here."""
        raise NotImplementedError(
            "TEP (Tensor Expert Parallelism) is not implemented for TrtLLM backend"
        )

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Always raise ``NotImplementedError``; DEP is unsupported here."""
        raise NotImplementedError(
            "DEP (Data Expert Parallelism) is not implemented for TrtLLM backend"
        )

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value after ``--served-model-name``, else the default.

        ``DEFAULT_MODEL_NAME`` is returned, with a warning, when the worker
        lookup or args validation raises ``ValueError``/``KeyError`` or when
        the flag has no following value.
        """
        cfg = Config.model_validate(config)
        try:
            worker = get_worker_service_from_config(cfg, backend="trtllm")
            raw_args = validate_and_get_worker_args(worker, backend="trtllm")
        except (ValueError, KeyError):
            logger.warning(
                f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        tokens = break_arguments(raw_args)
        # Pair every token with its successor; a flag in last position has no
        # successor and is therefore never matched.
        for flag, value in zip(tokens, tokens[1:]):
            if flag == "--served-model-name":
                return value

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return ``DYNAMO_RUN_DEFAULT_PORT``.

        The Frontend args are not inspected for this backend. A warning is
        logged when the ``Frontend`` service or its container is missing.
        """
        cfg = Config.model_validate(config)
        frontend = cfg.spec.services.get("Frontend")
        has_container = (
            frontend and frontend.extraPodSpec and frontend.extraPodSpec.mainContainer
        )
        if not has_container:
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
        # TRT-LLM frontend doesn't have args, it uses the default port.
        return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Parse the paged KV cache max-token count from a log file.

        Expected line format:
        ``[TensorRT-LLM][INFO] [MemUsageChange] Allocated XX GiB for max tokens
        in paged KV cache (XXXXXX).``

        Returns the parenthesised number from the first matching line. When no
        line matches or reading fails, warnings are logged and 100000 is
        returned. ``attention_dp_size`` is not used.
        """
        try:
            with open(dynamo_log_fn, "r") as log_file:
                for entry in log_file:
                    if (
                        "Allocated" not in entry
                        or "for max tokens in paged KV cache" not in entry
                    ):
                        continue
                    found = re.search(r"paged KV cache \((\d+)\)", entry)
                    if not found:
                        continue
                    max_tokens = int(found.group(1))
                    logger.info(f"Found TRT-LLM KV cache max tokens: {max_tokens}")
                    return max_tokens
        except Exception as e:
            logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")

        # Nothing usable in the log: fall back to a fixed default.
        logger.warning(
            "Could not find KV cache size in TRT-LLM logs, using default value of 100000"
        )
        return 100000
# Registry mapping a backend name to the config-modifier class for that backend.
# The values are the classes themselves, not instances.
CONFIG_MODIFIERS: dict[str, type[ConfigModifierProtocol]] = {
"vllm": VllmV1ConfigModifier,
"sglang": SGLangConfigModifier,
"trtllm": TrtllmConfigModifier,
}
def generate_dgd_config_with_planner( def generate_dgd_config_with_planner(
config_path: str, config_path: str,
config_modifier, config_modifier,
...@@ -1293,7 +512,3 @@ def generate_dgd_config_with_planner( ...@@ -1293,7 +512,3 @@ def generate_dgd_config_with_planner(
config_dict["spec"]["services"]["Planner"] = planner_dict config_dict["spec"]["services"]["Planner"] = planner_dict
return config_dict return config_dict
# Re-export WORKER_COMPONENT_NAMES for profile_sla.py
__all__ = ["CONFIG_MODIFIERS", "WORKER_COMPONENT_NAMES"]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from benchmarks.profiler.utils.config import ConfigModifierProtocol
from benchmarks.profiler.utils.config_modifiers.sglang import SGLangConfigModifier
from benchmarks.profiler.utils.config_modifiers.trtllm import TrtllmConfigModifier
from benchmarks.profiler.utils.config_modifiers.vllm import VllmV1ConfigModifier
# Registry mapping a backend name to the config-modifier class for that backend.
# "ConfigModifierProtocol" is quoted because it is imported only under
# TYPE_CHECKING, so the name does not exist at runtime.
CONFIG_MODIFIERS: dict[str, type["ConfigModifierProtocol"]] = {
"vllm": VllmV1ConfigModifier,
"sglang": SGLangConfigModifier,
"trtllm": TrtllmConfigModifier,
}
# Public names exported by this package.
__all__ = [
"VllmV1ConfigModifier",
"SGLangConfigModifier",
"TrtllmConfigModifier",
"CONFIG_MODIFIERS",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import re
from typing import Literal
from benchmarks.profiler.utils.config import (
Config,
append_argument,
break_arguments,
get_service_name_by_type,
get_worker_service_from_config,
remove_valued_arguments,
set_argument_value,
setup_worker_service_resources,
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import SubComponentType
# Module-level logger that emits INFO and above through its own stream handler.
# NOTE(review): the handler is attached at import time; if an ancestor logger
# also has a handler and propagation is left enabled, each record may be
# printed twice — confirm this is intended.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Format: timestamp - logger name - level - message, with second resolution.
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
class SGLangConfigModifier:
@classmethod
def convert_config(
cls,
config: dict,
target: Literal["prefill", "decode"],
is_moe_model: bool = False,
) -> dict:
cfg = Config.model_validate(config)
# set metadata name
cfg.metadata.name = "sglang-agg"
# disable planner
if "Planner" in cfg.spec.services:
del cfg.spec.services["Planner"]
if target == "prefill":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
cfg, "sglang", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
cfg, "sglang", SubComponentType.DECODE
)
# convert prefill worker into decode worker
cfg.spec.services[decode_service_name] = cfg.spec.services[
prefill_service_name
]
del cfg.spec.services[prefill_service_name]
# Set subComponentType for aggregated mode (using decode worker for prefill-only)
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg,
backend="sglang",
sub_component_type=SubComponentType.DECODE,
)
args = validate_and_get_worker_args(worker_service, backend="sglang")
args = break_arguments(args)
# remove disagg flags
args = remove_valued_arguments(args, "--disaggregation-mode")
args = remove_valued_arguments(args, "--disaggregation-transfer-backend")
args = remove_valued_arguments(args, "--disaggregation-bootstrap-port")
# disable prefix caching
if "--disable-radix-cache" not in args:
args = append_argument(args, "--disable-radix-cache")
worker_service.extraPodSpec.mainContainer.args = args
elif target == "decode":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
cfg, "sglang", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
cfg, "sglang", SubComponentType.DECODE
)
# delete prefill worker
del cfg.spec.services[prefill_service_name]
# Set subComponentType for aggregated decode-only mode
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg,
backend="sglang",
sub_component_type=SubComponentType.DECODE,
)
args = validate_and_get_worker_args(worker_service, backend="sglang")
args = break_arguments(args)
# remove disagg flags
args = remove_valued_arguments(args, "--disaggregation-mode")
args = remove_valued_arguments(args, "--disaggregation-transfer-backend")
args = remove_valued_arguments(args, "--disaggregation-bootstrap-port")
# enable prefix caching
if "--disable-radix-cache" in args:
args.remove("--disable-radix-cache")
if is_moe_model:
# need to use round_robin dp attention routing for MoE models to ensure kv reuse can skip prefill
if "--load-balance-method" in args:
idx = args.index("--load-balance-method")
args[idx + 1] = "round_robin"
else:
args = append_argument(
args, ["--load-balance-method", "round_robin"]
)
worker_service.extraPodSpec.mainContainer.args = args
# set num workers to 1
# Use the inferred decode service name
final_decode_service_name = get_service_name_by_type(
cfg, "sglang", SubComponentType.DECODE
)
decode_worker_config = cfg.spec.services[final_decode_service_name]
decode_worker_config.replicas = 1
return cfg.model_dump()
@classmethod
def set_config_tp_size(
cls,
config: dict,
tp_size: int,
component_type: SubComponentType = SubComponentType.DECODE,
):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(
cfg, backend="sglang", sub_component_type=component_type
)
# Set up resources
setup_worker_service_resources(worker_service, tp_size)
# Get and validate args
args = validate_and_get_worker_args(worker_service, backend="sglang")
# Set --tp argument
args = set_argument_value(args, "--tp", str(tp_size))
worker_service.extraPodSpec.mainContainer.args = args
return cfg.model_dump()
@classmethod
def set_config_tep_size(
cls,
config: dict,
tep_size: int,
num_gpus_per_node: int,
component_type: SubComponentType = SubComponentType.DECODE,
):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(
cfg, backend="sglang", sub_component_type=component_type
)
# Set up resources with multinode configuration
setup_worker_service_resources(worker_service, tep_size, num_gpus_per_node)
# Get and validate args
args = validate_and_get_worker_args(worker_service, backend="sglang")
# 1. Set --tp=tep_size, if not present add it
args = set_argument_value(args, "--tp", str(tep_size))
# 2. Set --ep-size=tep_size, if not present add it
args = set_argument_value(args, "--ep-size", str(tep_size))
# 3. Remove --dp if present
args = remove_valued_arguments(args, "--dp")
# 4. Remove --enable-dp-attention if present
if "--enable-dp-attention" in args:
args.remove("--enable-dp-attention")
worker_service.extraPodSpec.mainContainer.args = args
return cfg.model_dump()
@classmethod
def set_config_dep_size(
    cls,
    config: dict,
    dep_size: int,
    num_gpus_per_node: int,
    component_type: SubComponentType = SubComponentType.DECODE,
):
    """Configure an SGLang worker for DEP with ``--tp``, ``--dp`` and ``--ep-size`` set to ``dep_size``.

    ``--enable-dp-attention`` is appended when it is not already present.

    Args:
        config: Deployment config dict; validated with ``Config.model_validate``.
        dep_size: Value written to ``--tp``, ``--dp`` and ``--ep-size``.
        num_gpus_per_node: Forwarded to ``setup_worker_service_resources``
            together with ``dep_size``.
        component_type: Which worker sub-component to look up; defaults to
            the decode worker.

    Returns:
        The updated config as produced by ``model_dump()``.
    """
    parsed = Config.model_validate(config)
    target_service = get_worker_service_from_config(
        parsed, backend="sglang", sub_component_type=component_type
    )

    # Resource setup receives the per-node GPU count (multinode configuration).
    setup_worker_service_resources(target_service, dep_size, num_gpus_per_node)

    worker_args = validate_and_get_worker_args(target_service, backend="sglang")
    size_text = str(dep_size)

    # Tensor- and data-parallel sizes first ...
    for flag in ("--tp", "--dp"):
        worker_args = set_argument_value(worker_args, flag, size_text)

    # ... then the DP-attention switch ...
    if "--enable-dp-attention" not in worker_args:
        worker_args = append_argument(worker_args, "--enable-dp-attention")

    # ... and finally the expert-parallel size. The order is kept because it
    # determines where newly added flags land in the argument list.
    worker_args = set_argument_value(worker_args, "--ep-size", size_text)

    target_service.extraPodSpec.mainContainer.args = worker_args
    return parsed.model_dump()
@classmethod
def get_model_name(cls, config: dict) -> str:
    """Return the value following ``--served-model-name`` in the SGLang worker args.

    Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
    lookup or argument validation raises ``ValueError``/``KeyError``, or when
    the flag is absent or has no following value.
    """
    parsed = Config.model_validate(config)
    try:
        sglang_worker = get_worker_service_from_config(parsed, backend="sglang")
        raw_args = validate_and_get_worker_args(sglang_worker, backend="sglang")
    except (ValueError, KeyError):
        logger.warning(
            f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    tokens = break_arguments(raw_args)
    # Walk adjacent (flag, value) pairs; a trailing flag with no value never
    # forms a pair and is therefore ignored.
    for flag, value in zip(tokens, tokens[1:]):
        if flag == "--served-model-name":
            return value

    logger.warning(
        f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
    )
    return DEFAULT_MODEL_NAME
@classmethod
def get_port(cls, config: dict) -> int:
    """Return the integer following ``--http-port`` in the ``Frontend`` service args.

    ``DYNAMO_RUN_DEFAULT_PORT`` is returned (with a warning) when the
    ``Frontend`` service, its pod spec, its main container, or its args are
    missing/empty, or when the flag is absent, has no value, or the value is
    not an integer.
    """
    parsed = Config.model_validate(config)

    # Resolve Frontend -> extraPodSpec -> mainContainer, stopping at the first
    # falsy link in the chain.
    frontend = parsed.spec.services.get("Frontend")
    pod_spec = frontend.extraPodSpec if frontend else None
    container = pod_spec.mainContainer if pod_spec else None
    if not container:
        logger.warning(
            f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
        )
        return DYNAMO_RUN_DEFAULT_PORT

    raw_args = container.args
    if not raw_args:
        logger.warning(
            f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
        )
        return DYNAMO_RUN_DEFAULT_PORT

    tokens = break_arguments(raw_args)
    try:
        # ValueError: flag missing or value not an int; IndexError: flag is
        # the last token.
        return int(tokens[tokens.index("--http-port") + 1])
    except (ValueError, IndexError):
        logger.warning(
            f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
        )
        return DYNAMO_RUN_DEFAULT_PORT
@classmethod
def get_kv_cache_size_from_dynamo_log(
    cls, dynamo_log_fn: str, attention_dp_size: int = 1
) -> int:
    """Read the KV cache token count from a log file.

    The first line containing both ``"KV Cache is allocated"`` and
    ``"#tokens:"`` whose ``#tokens:`` value is a run of digits decides the
    result.

    Args:
        dynamo_log_fn: Path of the log file to scan.
        attention_dp_size: Multiplier applied to the parsed token count.

    Returns:
        ``tokens * attention_dp_size`` for the first matching line, or ``0``
        when nothing matches or any exception occurs while reading (the
        exception is logged as a warning, not raised).
    """
    try:
        with open(dynamo_log_fn, "r") as log_file:
            for entry in log_file:
                if "KV Cache is allocated" not in entry or "#tokens:" not in entry:
                    continue
                # Extract the number after "#tokens:"
                found = re.search(r"#tokens:\s*(\d+)", entry)
                if found:
                    return int(found.group(1)) * attention_dp_size
    except Exception as e:
        logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")
    return 0
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import re
from typing import Literal
from benchmarks.profiler.utils.config import (
Config,
append_argument,
break_arguments,
get_service_name_by_type,
get_worker_service_from_config,
parse_override_engine_args,
remove_valued_arguments,
setup_worker_service_resources,
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import SubComponentType
# Module-level logging setup: records at INFO and above from this module's
# logger go to a dedicated stream handler, formatted as
# "timestamp - logger name - level - message".
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
class TrtllmConfigModifier:
    """Config rewrites used for the TensorRT-LLM ("trtllm") backend.

    All methods are classmethods operating on plain config dicts. Methods
    that edit the config validate it with ``Config.model_validate`` and
    return the result of ``model_dump()``.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Convert a disaggregated config into a single-worker aggregated one.

        Args:
            config: Deployment config dict; validated with
                ``Config.model_validate``.
            target: ``"prefill"`` re-registers the prefill worker under the
                decode service name; ``"decode"`` removes the prefill worker
                (if present) and keeps the decode worker.
            is_moe_model: Must be false; MoE is not implemented for this
                backend.

        Returns:
            The rewritten config as produced by ``model_dump()``, with the
            ``Planner`` service removed and the decode service's replicas
            set to 1.

        Raises:
            NotImplementedError: If ``is_moe_model`` is true.
        """
        if is_moe_model:
            raise NotImplementedError(
                "MoE model support is not implemented for TrtLLM backend"
            )

        cfg = Config.model_validate(config)

        # set metadata name
        cfg.metadata.name = "trtllm-agg"

        # disable planner
        if "Planner" in cfg.spec.services:
            del cfg.spec.services["Planner"]

        if target in ("prefill", "decode"):
            # Get service names by inferring from subComponentType first
            prefill_service_name = get_service_name_by_type(
                cfg, "trtllm", SubComponentType.PREFILL
            )
            decode_service_name = get_service_name_by_type(
                cfg, "trtllm", SubComponentType.DECODE
            )

            if target == "prefill":
                # Prefill-only aggregated setup: the prefill worker takes
                # over the decode worker's service name.
                cfg.spec.services[decode_service_name] = cfg.spec.services[
                    prefill_service_name
                ]
                del cfg.spec.services[prefill_service_name]
            elif prefill_service_name in cfg.spec.services:
                # Decode-only aggregated setup: remove the prefill worker if
                # it exists (a missing prefill worker is not an error).
                del cfg.spec.services[prefill_service_name]

            # The remaining worker is always tagged as the decode
            # sub-component, including in prefill-only mode.
            cfg.spec.services[decode_service_name].subComponentType = "decode"

            cls._convert_worker_to_aggregated(
                cfg, prefill_only=(target == "prefill")
            )

        # Set num workers to 1
        # Use the inferred decode service name
        final_decode_service_name = get_service_name_by_type(
            cfg, "trtllm", SubComponentType.DECODE
        )
        cfg.spec.services[final_decode_service_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def _convert_worker_to_aggregated(cls, cfg, *, prefill_only: bool) -> None:
        """Strip disaggregation args from the decode worker and merge agg overrides.

        Any user-supplied ``--override-engine-args`` are parsed and kept; the
        keys below are merged on top and the result is appended back as a
        single ``--override-engine-args`` JSON argument.

        Args:
            cfg: Validated config object whose decode worker is edited in place.
            prefill_only: True when the worker originated from the prefill
                service; selects the prefill-specific overrides.
        """
        worker_service = get_worker_service_from_config(
            cfg,
            backend="trtllm",
            sub_component_type=SubComponentType.DECODE,
        )
        args = validate_and_get_worker_args(worker_service, backend="trtllm")
        args = break_arguments(args)

        # Remove disaggregation args
        for flag in ("--disaggregation-mode", "--disaggregation-strategy"):
            args = remove_valued_arguments(args, flag)

        # Keep the original extra-engine-args which may contain user settings;
        # only the override dict is modified.
        override_dict, args = parse_override_engine_args(args)

        if "kv_cache_config" not in override_dict:
            override_dict["kv_cache_config"] = {}
        # Prefill-only: no KV reuse. Decode-only: block reuse is enabled so
        # that prefill can be skipped.
        override_dict["kv_cache_config"]["enable_block_reuse"] = not prefill_only

        if prefill_only:
            # Enable overlap scheduler (disabled in prefill.yaml but needed for agg)
            override_dict["disable_overlap_scheduler"] = False

        # Remove cache transceiver (not needed in agg mode)
        override_dict["cache_transceiver_config"] = None

        override_str = json.dumps(override_dict)
        args = append_argument(args, ["--override-engine-args", override_str])

        worker_service.extraPodSpec.mainContainer.args = args

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Set ``tensor_parallel_size`` in the worker's ``--override-engine-args``.

        Args:
            config: Deployment config dict; validated with
                ``Config.model_validate``.
            tp_size: Value stored under ``tensor_parallel_size`` and passed
                to ``setup_worker_service_resources``.
            component_type: Which worker sub-component to look up; defaults
                to the decode worker.

        Returns:
            The updated config as produced by ``model_dump()``.
        """
        cfg = Config.model_validate(config)

        # NOTE(review): the original comment states this assumes
        # convert_config has already been called - confirm with callers.
        worker_service = get_worker_service_from_config(
            cfg, backend="trtllm", sub_component_type=component_type
        )

        # Set up resources
        setup_worker_service_resources(worker_service, tp_size)

        # Break arguments to handle both joined strings and lists
        args = validate_and_get_worker_args(worker_service, backend="trtllm")
        args = break_arguments(args)

        # For TRT-LLM the TP size lives inside the override-engine-args JSON
        # rather than in a dedicated command-line flag.
        override_dict, args = parse_override_engine_args(args)
        override_dict["tensor_parallel_size"] = tp_size
        override_str = json.dumps(override_dict)
        args = append_argument(args, ["--override-engine-args", override_str])

        worker_service.extraPodSpec.mainContainer.args = args

        return cfg.model_dump()

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Not supported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "TEP (Tensor Expert Parallelism) is not implemented for TrtLLM backend"
        )

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Not supported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "DEP (Data Expert Parallelism) is not implemented for TrtLLM backend"
        )

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--served-model-name`` in the worker args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        lookup or validation raises ``ValueError``/``KeyError``, or when the
        flag is absent or has no following value.
        """
        cfg = Config.model_validate(config)
        try:
            worker_service = get_worker_service_from_config(cfg, backend="trtllm")
            args = validate_and_get_worker_args(worker_service, backend="trtllm")
        except (ValueError, KeyError):
            logger.warning(
                f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        args = break_arguments(args)
        for i, arg in enumerate(args):
            if arg == "--served-model-name" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the frontend port, which is always ``DYNAMO_RUN_DEFAULT_PORT``.

        The config is still validated, and a warning is logged when the
        ``Frontend`` service, its pod spec, or its main container is missing.
        """
        cfg = Config.model_validate(config)
        frontend_service = cfg.spec.services.get("Frontend")
        if (
            not frontend_service
            or not frontend_service.extraPodSpec
            or not frontend_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        # TRT-LLM frontend doesn't have args, it uses the default port
        return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Read the paged KV cache max-token count from a TRT-LLM log file.

        Expected line format:
        ``[TensorRT-LLM][INFO] [MemUsageChange] Allocated XX GiB for max tokens in paged KV cache (XXXXXX).``

        Args:
            dynamo_log_fn: Path of the log file to scan.
            attention_dp_size: Not used by this implementation.

        Returns:
            The number in parentheses from the first matching line, or the
            fallback value ``100000`` when no line matches or reading the
            file fails (failures are logged as warnings, not raised).
        """
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    # Look for the specific TRT-LLM KV cache allocation log
                    if (
                        "Allocated" in line
                        and "for max tokens in paged KV cache" in line
                    ):
                        # Extract the number in parentheses at the end
                        match = re.search(r"paged KV cache \((\d+)\)", line)
                        if match:
                            max_tokens = int(match.group(1))
                            logger.info(
                                f"Found TRT-LLM KV cache max tokens: {max_tokens}"
                            )
                            return max_tokens
        except Exception as e:
            logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")

        # Return a reasonable default if we couldn't find the KV cache size in logs
        logger.warning(
            "Could not find KV cache size in TRT-LLM logs, using default value of 100000"
        )
        return 100000  # Default fallback value for TRT-LLM
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Literal
from benchmarks.profiler.utils.config import (
Config,
append_argument,
break_arguments,
get_service_name_by_type,
get_worker_service_from_config,
setup_worker_service_resources,
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import SubComponentType
# Module-level logging setup: records at INFO and above from this module's
# logger go to a dedicated stream handler, formatted as
# "timestamp - logger name - level - message".
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
class VllmV1ConfigModifier:
    """Config rewrites used for the vLLM ("vllm") backend.

    All methods are classmethods operating on plain config dicts. Methods
    that edit the config validate it with ``Config.model_validate`` and
    return the result of ``model_dump()``.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Convert a disaggregated config into a single-worker aggregated one.

        Args:
            config: Deployment config dict; validated with
                ``Config.model_validate``.
            target: ``"prefill"`` re-registers the prefill worker under the
                decode service name and disables prefix caching;
                ``"decode"`` deletes the prefill worker and enables prefix
                caching.
            is_moe_model: Must be false; MoE is not implemented for this
                backend.

        Returns:
            The rewritten config as produced by ``model_dump()``, with the
            ``Planner`` service removed and the decode service's replicas
            set to 1.

        Raises:
            NotImplementedError: If ``is_moe_model`` is true.
        """
        if is_moe_model:
            raise NotImplementedError(
                "MoE model support is not implemented for VLLM backend"
            )

        cfg = Config.model_validate(config)

        # set metadata name
        cfg.metadata.name = "vllm-agg"

        # disable planner
        if "Planner" in cfg.spec.services:
            del cfg.spec.services["Planner"]

        if target == "prefill":
            # Get service names by inferring from subComponentType first
            prefill_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.PREFILL
            )
            decode_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.DECODE
            )

            # convert prefill worker into decode worker
            cfg.spec.services[decode_service_name] = cfg.spec.services[
                prefill_service_name
            ]
            del cfg.spec.services[prefill_service_name]

            # Set subComponentType for aggregated mode (using decode worker for prefill-only)
            cfg.spec.services[decode_service_name].subComponentType = "decode"

            worker_service = get_worker_service_from_config(
                cfg,
                backend="vllm",
                sub_component_type=SubComponentType.DECODE,
            )
            args = validate_and_get_worker_args(worker_service, backend="vllm")
            args = break_arguments(args)

            # remove --is-prefill-worker flag; guarded like the other
            # removals so a config without the flag does not raise ValueError
            if "--is-prefill-worker" in args:
                args.remove("--is-prefill-worker")

            # disable prefix caching
            if "--enable-prefix-caching" in args:
                args.remove("--enable-prefix-caching")
            if "--no-enable-prefix-caching" not in args:
                args = append_argument(args, "--no-enable-prefix-caching")

            worker_service.extraPodSpec.mainContainer.args = args

        elif target == "decode":
            # Get service names by inferring from subComponentType first
            prefill_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.PREFILL
            )
            decode_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.DECODE
            )

            # delete prefill worker
            del cfg.spec.services[prefill_service_name]

            # Set subComponentType for aggregated decode-only mode
            cfg.spec.services[decode_service_name].subComponentType = "decode"

            worker_service = get_worker_service_from_config(
                cfg,
                backend="vllm",
                sub_component_type=SubComponentType.DECODE,
            )
            args = validate_and_get_worker_args(worker_service, backend="vllm")
            args = break_arguments(args)

            # enable prefix caching
            if "--enable-prefix-caching" not in args:
                args = append_argument(args, "--enable-prefix-caching")
            if "--no-enable-prefix-caching" in args:
                args.remove("--no-enable-prefix-caching")

            worker_service.extraPodSpec.mainContainer.args = args

        # set num workers to 1
        # Use the inferred decode service name
        final_decode_service_name = get_service_name_by_type(
            cfg, "vllm", SubComponentType.DECODE
        )
        cfg.spec.services[final_decode_service_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Set the worker's ``--tensor-parallel-size`` argument to ``tp_size``.

        Args:
            config: Deployment config dict; validated with
                ``Config.model_validate``.
            tp_size: Value written (as a string) after
                ``--tensor-parallel-size`` and passed to
                ``setup_worker_service_resources``.
            component_type: Which worker sub-component to look up; defaults
                to the decode worker.

        Returns:
            The updated config as produced by ``model_dump()``.
        """
        cfg = Config.model_validate(config)
        worker_service = get_worker_service_from_config(
            cfg, backend="vllm", sub_component_type=component_type
        )

        # Set up resources
        setup_worker_service_resources(worker_service, tp_size)

        # Get and validate args
        args = validate_and_get_worker_args(worker_service, backend="vllm")
        args = break_arguments(args)

        if "--tensor-parallel-size" not in args:
            args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])
        else:
            idx = args.index("--tensor-parallel-size")
            if idx + 1 < len(args):
                args[idx + 1] = str(tp_size)
            else:
                # The flag is the final token with no value after it; supply
                # the value instead of raising IndexError.
                args = append_argument(args, str(tp_size))

        worker_service.extraPodSpec.mainContainer.args = args

        return cfg.model_dump()

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Not supported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "TEP (Tensor Expert Parallelism) is not implemented for VLLM backend"
        )

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Not supported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "DEP (Data Expert Parallelism) is not implemented for VLLM backend"
        )

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--model`` in the worker args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        lookup or validation raises ``ValueError``/``KeyError``, or when the
        flag is absent or has no following value.
        """
        cfg = Config.model_validate(config)
        try:
            worker_service = get_worker_service_from_config(cfg, backend="vllm")
            args = validate_and_get_worker_args(worker_service, backend="vllm")
        except (ValueError, KeyError):
            logger.warning(
                f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        args = break_arguments(args)
        for i, arg in enumerate(args):
            if arg == "--model" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the integer following ``--http-port`` in the ``Frontend`` service args.

        ``DYNAMO_RUN_DEFAULT_PORT`` is returned (with a warning) when the
        ``Frontend`` service, its pod spec, its main container, or its args
        are missing/empty, or when the flag is absent, has no value, or the
        value is not an integer.
        """
        cfg = Config.model_validate(config)
        frontend_service = cfg.spec.services.get("Frontend")
        if (
            not frontend_service
            or not frontend_service.extraPodSpec
            or not frontend_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = frontend_service.extraPodSpec.mainContainer.args
        if not args:
            logger.warning(
                f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = break_arguments(args)
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Estimate the KV cache size in tokens from a vLLM log file.

        The first line containing ``"Maximum concurrency for"`` is parsed as
        ``... Maximum concurrency for <tokens> tokens per request: <ratio>?``
        where ``?`` is one trailing character that is dropped.

        Args:
            dynamo_log_fn: Path of the log file to scan.
            attention_dp_size: Not used by this implementation.

        Returns:
            ``int(tokens * ratio)`` for the first matching line, or ``0``
            when no line matches, the line cannot be parsed, or the file
            cannot be read (failures are logged as warnings, not raised).
        """
        # Bound up front so the except handler can always format its message,
        # even when open() fails before the first line is read.
        line = None
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    if "Maximum concurrency for" in line:
                        line = line.strip().split("Maximum concurrency for ")[1]
                        fields = line.split(" tokens per request: ")
                        # The token count may contain thousands separators.
                        token_count = int(fields[0].replace(",", ""))
                        # Drops the final character - presumably a trailing
                        # "x" as in "12.34x"; confirm against the vLLM log
                        # format.
                        concurrency = float(fields[1][:-1])

                        logger.info(
                            f"Found KV cache info: {token_count} x {concurrency} = {int(token_count * concurrency)}"
                        )
                        return int(token_count * concurrency)
        except Exception as e:
            logger.warning(
                f"Failed to parse KV cache size from line: {line}. Error: {e}"
            )
        return 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment