Unverified commit f26d8da7, authored by Hongkuan Zhou and committed by GitHub
Browse files

chore: separate config modifier into multiple files (#3627)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent caaea7ad
...@@ -23,11 +23,8 @@ import numpy as np ...@@ -23,11 +23,8 @@ import numpy as np
import yaml import yaml
from benchmarks.profiler.utils.aiperf import benchmark_decode, benchmark_prefill from benchmarks.profiler.utils.aiperf import benchmark_decode, benchmark_prefill
from benchmarks.profiler.utils.config import ( from benchmarks.profiler.utils.config import generate_dgd_config_with_planner
CONFIG_MODIFIERS, from benchmarks.profiler.utils.config_modifiers import CONFIG_MODIFIERS
WORKER_COMPONENT_NAMES,
generate_dgd_config_with_planner,
)
from benchmarks.profiler.utils.estimate_perf import AIConfiguratorPerfEstimator from benchmarks.profiler.utils.estimate_perf import AIConfiguratorPerfEstimator
from benchmarks.profiler.utils.planner_utils import add_planner_arguments_to_parser from benchmarks.profiler.utils.planner_utils import add_planner_arguments_to_parser
from benchmarks.profiler.utils.plot import ( from benchmarks.profiler.utils.plot import (
...@@ -53,6 +50,7 @@ from deploy.utils.dynamo_deployment import ( ...@@ -53,6 +50,7 @@ from deploy.utils.dynamo_deployment import (
DynamoDeploymentClient, DynamoDeploymentClient,
cleanup_remaining_deployments, cleanup_remaining_deployments,
) )
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
......
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from benchmarks.profiler.utils.config import ConfigModifierProtocol
from benchmarks.profiler.utils.config_modifiers.sglang import SGLangConfigModifier
from benchmarks.profiler.utils.config_modifiers.trtllm import TrtllmConfigModifier
from benchmarks.profiler.utils.config_modifiers.vllm import VllmV1ConfigModifier
# Registry mapping a backend name ("vllm", "sglang", "trtllm") to the
# config-modifier class that handles that backend. The value type is quoted,
# presumably because ConfigModifierProtocol is imported only for type checking
# (see the TYPE_CHECKING guard above) -- confirm against the original layout.
CONFIG_MODIFIERS: dict[str, type["ConfigModifierProtocol"]] = {
"vllm": VllmV1ConfigModifier,
"sglang": SGLangConfigModifier,
"trtllm": TrtllmConfigModifier,
}
# Public names of this package: the three modifier classes and the registry.
__all__ = [
"VllmV1ConfigModifier",
"SGLangConfigModifier",
"TrtllmConfigModifier",
"CONFIG_MODIFIERS",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import re
from typing import Literal
from benchmarks.profiler.utils.config import (
Config,
append_argument,
break_arguments,
get_service_name_by_type,
get_worker_service_from_config,
remove_valued_arguments,
set_argument_value,
setup_worker_service_resources,
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import SubComponentType
# Module logger: INFO and above, echoed through a dedicated stream handler
# that prefixes every record with a timestamp, logger name and level.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
class SGLangConfigModifier:
    """Classmethod helpers that rewrite and inspect SGLang deployment configs.

    Each method validates the incoming ``config`` dict through
    ``Config.model_validate``; the mutating methods hand back
    ``cfg.model_dump()`` of the edited model.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Collapse a prefill + decode config into a single-worker config.

        Args:
            config: Deployment config as a plain dict.
            target: ``"prefill"`` keeps the prefill worker (re-registered under
                the decode service name); ``"decode"`` keeps the decode worker.
                Any other value skips the worker rewrite.
            is_moe_model: For ``target == "decode"`` only, force
                ``--load-balance-method round_robin``.

        Returns:
            ``model_dump()`` of the edited config: named ``"sglang-agg"``, any
            ``"Planner"`` service removed, decode service ``replicas`` set to 1.
        """
        cfg = Config.model_validate(config)
        cfg.metadata.name = "sglang-agg"

        # The planner is dropped from the converted deployment.
        if "Planner" in cfg.spec.services:
            del cfg.spec.services["Planner"]

        disagg_flags = (
            "--disaggregation-mode",
            "--disaggregation-transfer-backend",
            "--disaggregation-bootstrap-port",
        )

        if target == "prefill":
            prefill_name = get_service_name_by_type(
                cfg, "sglang", SubComponentType.PREFILL
            )
            decode_name = get_service_name_by_type(
                cfg, "sglang", SubComponentType.DECODE
            )

            # Re-register the prefill worker under the decode service name and
            # tag it as "decode", so it is the only worker left.
            cfg.spec.services[decode_name] = cfg.spec.services[prefill_name]
            del cfg.spec.services[prefill_name]
            cfg.spec.services[decode_name].subComponentType = "decode"

            worker = get_worker_service_from_config(
                cfg,
                backend="sglang",
                sub_component_type=SubComponentType.DECODE,
            )
            worker_args = break_arguments(
                validate_and_get_worker_args(worker, backend="sglang")
            )
            for flag in disagg_flags:
                worker_args = remove_valued_arguments(worker_args, flag)

            # Prefix caching off: make sure the radix cache is disabled.
            if "--disable-radix-cache" not in worker_args:
                worker_args = append_argument(worker_args, "--disable-radix-cache")
            worker.extraPodSpec.mainContainer.args = worker_args
        elif target == "decode":
            prefill_name = get_service_name_by_type(
                cfg, "sglang", SubComponentType.PREFILL
            )
            decode_name = get_service_name_by_type(
                cfg, "sglang", SubComponentType.DECODE
            )

            # Only the decode worker survives.
            del cfg.spec.services[prefill_name]
            cfg.spec.services[decode_name].subComponentType = "decode"

            worker = get_worker_service_from_config(
                cfg,
                backend="sglang",
                sub_component_type=SubComponentType.DECODE,
            )
            worker_args = break_arguments(
                validate_and_get_worker_args(worker, backend="sglang")
            )
            for flag in disagg_flags:
                worker_args = remove_valued_arguments(worker_args, flag)

            # Prefix caching on: drop the flag that disables the radix cache.
            if "--disable-radix-cache" in worker_args:
                worker_args.remove("--disable-radix-cache")

            if is_moe_model:
                # MoE models need round_robin dp attention routing so that kv
                # reuse can skip prefill.
                if "--load-balance-method" in worker_args:
                    value_pos = worker_args.index("--load-balance-method") + 1
                    worker_args[value_pos] = "round_robin"
                else:
                    worker_args = append_argument(
                        worker_args, ["--load-balance-method", "round_robin"]
                    )
            worker.extraPodSpec.mainContainer.args = worker_args

        # Exactly one replica of the (re-resolved) decode service.
        decode_worker = cfg.spec.services[
            get_service_name_by_type(cfg, "sglang", SubComponentType.DECODE)
        ]
        decode_worker.replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Set ``--tp`` to ``tp_size`` on the selected worker.

        Resource setup is delegated to
        ``setup_worker_service_resources(worker, tp_size)``.

        Returns:
            ``model_dump()`` of the edited config.
        """
        cfg = Config.model_validate(config)
        worker = get_worker_service_from_config(
            cfg, backend="sglang", sub_component_type=component_type
        )
        setup_worker_service_resources(worker, tp_size)

        worker.extraPodSpec.mainContainer.args = set_argument_value(
            validate_and_get_worker_args(worker, backend="sglang"),
            "--tp",
            str(tp_size),
        )
        return cfg.model_dump()

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Configure the selected worker for a TEP size of ``tep_size``.

        Sets ``--tp`` and ``--ep-size`` to ``tep_size`` and removes ``--dp`` and
        ``--enable-dp-attention``. Resource setup is delegated to
        ``setup_worker_service_resources(worker, tep_size, num_gpus_per_node)``.

        Returns:
            ``model_dump()`` of the edited config.
        """
        cfg = Config.model_validate(config)
        worker = get_worker_service_from_config(
            cfg, backend="sglang", sub_component_type=component_type
        )
        setup_worker_service_resources(worker, tep_size, num_gpus_per_node)

        worker_args = validate_and_get_worker_args(worker, backend="sglang")
        for sized_flag in ("--tp", "--ep-size"):
            worker_args = set_argument_value(worker_args, sized_flag, str(tep_size))

        # TEP runs without data parallelism: drop --dp and its attention switch.
        worker_args = remove_valued_arguments(worker_args, "--dp")
        if "--enable-dp-attention" in worker_args:
            worker_args.remove("--enable-dp-attention")

        worker.extraPodSpec.mainContainer.args = worker_args
        return cfg.model_dump()

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Configure the selected worker for a DEP size of ``dep_size``.

        Sets ``--tp``, ``--dp`` and ``--ep-size`` to ``dep_size`` and makes sure
        ``--enable-dp-attention`` is present. Resource setup is delegated to
        ``setup_worker_service_resources(worker, dep_size, num_gpus_per_node)``.

        Returns:
            ``model_dump()`` of the edited config.
        """
        cfg = Config.model_validate(config)
        worker = get_worker_service_from_config(
            cfg, backend="sglang", sub_component_type=component_type
        )
        setup_worker_service_resources(worker, dep_size, num_gpus_per_node)

        size = str(dep_size)
        worker_args = validate_and_get_worker_args(worker, backend="sglang")
        worker_args = set_argument_value(worker_args, "--tp", size)
        worker_args = set_argument_value(worker_args, "--dp", size)
        if "--enable-dp-attention" not in worker_args:
            worker_args = append_argument(worker_args, "--enable-dp-attention")
        worker_args = set_argument_value(worker_args, "--ep-size", size)

        worker.extraPodSpec.mainContainer.args = worker_args
        return cfg.model_dump()

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--served-model-name`` in the worker args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        lookup raises ``ValueError``/``KeyError`` or the flag has no value.
        """
        cfg = Config.model_validate(config)
        try:
            worker = get_worker_service_from_config(cfg, backend="sglang")
            worker_args = validate_and_get_worker_args(worker, backend="sglang")
        except (ValueError, KeyError):
            logger.warning(
                f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        tokens = break_arguments(worker_args)
        # Walk (flag, next-token) pairs; a trailing flag has no pair to match.
        for flag, value in zip(tokens, tokens[1:]):
            if flag == "--served-model-name":
                return value

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the integer after ``--http-port`` in the ``"Frontend"`` args.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the
        frontend service, its main container, its args, or a parsable port
        value is missing.
        """
        cfg = Config.model_validate(config)
        frontend = cfg.spec.services.get("Frontend")

        has_container = (
            frontend
            and frontend.extraPodSpec
            and frontend.extraPodSpec.mainContainer
        )
        if not has_container:
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        raw_args = frontend.extraPodSpec.mainContainer.args
        if not raw_args:
            logger.warning(
                f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        tokens = break_arguments(raw_args)
        try:
            return int(tokens[tokens.index("--http-port") + 1])
        except (ValueError, IndexError):
            # Flag absent, flag without a value, or a non-integer value.
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Read the KV cache token count from a log file.

        Uses the first line that contains both ``"KV Cache is allocated"`` and
        a ``#tokens: <digits>`` field, and returns that number multiplied by
        ``attention_dp_size``. Returns 0 when no such line exists or when
        reading/parsing raises (the error is logged as a warning).
        """
        try:
            with open(dynamo_log_fn, "r") as log_file:
                for entry in log_file:
                    if "KV Cache is allocated" not in entry:
                        continue
                    if "#tokens:" not in entry:
                        continue
                    found = re.search(r"#tokens:\s*(\d+)", entry)
                    if found:
                        return int(found.group(1)) * attention_dp_size
        except Exception as e:
            logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")
        return 0
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import re
from typing import Literal
from benchmarks.profiler.utils.config import (
Config,
append_argument,
break_arguments,
get_service_name_by_type,
get_worker_service_from_config,
parse_override_engine_args,
remove_valued_arguments,
setup_worker_service_resources,
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import SubComponentType
# Module logger: INFO and above, echoed through a dedicated stream handler
# that prefixes every record with a timestamp, logger name and level.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
class TrtllmConfigModifier:
    """Classmethod helpers that rewrite and inspect TRT-LLM deployment configs.

    Each method validates the incoming ``config`` dict through
    ``Config.model_validate``; the mutating methods hand back
    ``cfg.model_dump()`` of the edited model. Engine settings are changed by
    appending a JSON ``--override-engine-args`` argument.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Collapse a prefill + decode config into a single-worker config.

        Args:
            config: Deployment config as a plain dict.
            target: ``"prefill"`` keeps the prefill worker (re-registered under
                the decode service name); ``"decode"`` keeps the decode worker.
                Any other value skips the worker rewrite.
            is_moe_model: Must be falsy; MoE is not supported here.

        Returns:
            ``model_dump()`` of the edited config: named ``"trtllm-agg"``, any
            ``"Planner"`` service removed, decode service ``replicas`` set to 1.

        Raises:
            NotImplementedError: If ``is_moe_model`` is truthy.
        """
        if is_moe_model:
            raise NotImplementedError(
                "MoE model support is not implemented for TrtLLM backend"
            )

        cfg = Config.model_validate(config)
        cfg.metadata.name = "trtllm-agg"

        # The planner is dropped from the converted deployment.
        if "Planner" in cfg.spec.services:
            del cfg.spec.services["Planner"]

        disagg_flags = ("--disaggregation-mode", "--disaggregation-strategy")

        if target == "prefill":
            prefill_name = get_service_name_by_type(
                cfg, "trtllm", SubComponentType.PREFILL
            )
            decode_name = get_service_name_by_type(
                cfg, "trtllm", SubComponentType.DECODE
            )

            # Re-register the prefill worker under the decode service name and
            # tag it as "decode", so it is the only worker left.
            cfg.spec.services[decode_name] = cfg.spec.services[prefill_name]
            del cfg.spec.services[prefill_name]
            cfg.spec.services[decode_name].subComponentType = "decode"

            worker = get_worker_service_from_config(
                cfg,
                backend="trtllm",
                sub_component_type=SubComponentType.DECODE,
            )
            worker_args = break_arguments(
                validate_and_get_worker_args(worker, backend="trtllm")
            )
            for flag in disagg_flags:
                worker_args = remove_valued_arguments(worker_args, flag)

            # The original extra-engine-args (prefill.yaml) stay in place; any
            # override-engine-args the user already set are merged with ours.
            overrides, worker_args = parse_override_engine_args(worker_args)

            # Prefill-only conversion:
            #   - block reuse off (no KV reuse for prefill-only)
            #   - overlap scheduler on (disable flag set to False)
            #   - cache_transceiver_config nulled (not needed in agg mode)
            overrides.setdefault("kv_cache_config", {})["enable_block_reuse"] = False
            overrides["disable_overlap_scheduler"] = False
            overrides["cache_transceiver_config"] = None

            worker_args = append_argument(
                worker_args, ["--override-engine-args", json.dumps(overrides)]
            )
            worker.extraPodSpec.mainContainer.args = worker_args
        elif target == "decode":
            prefill_name = get_service_name_by_type(
                cfg, "trtllm", SubComponentType.PREFILL
            )
            decode_name = get_service_name_by_type(
                cfg, "trtllm", SubComponentType.DECODE
            )

            # Only the decode worker survives; it already has the right name.
            del cfg.spec.services[prefill_name]
            cfg.spec.services[decode_name].subComponentType = "decode"

            worker = get_worker_service_from_config(
                cfg,
                backend="trtllm",
                sub_component_type=SubComponentType.DECODE,
            )
            worker_args = break_arguments(
                validate_and_get_worker_args(worker, backend="trtllm")
            )
            for flag in disagg_flags:
                worker_args = remove_valued_arguments(worker_args, flag)

            # The original extra-engine-args (decode.yaml) stay in place; any
            # override-engine-args the user already set are merged with ours.
            overrides, worker_args = parse_override_engine_args(worker_args)

            # Decode-only conversion:
            #   - block reuse on (to skip prefill in decode-only)
            #   - cache_transceiver_config nulled (not needed in agg mode)
            overrides.setdefault("kv_cache_config", {})["enable_block_reuse"] = True
            overrides["cache_transceiver_config"] = None

            worker_args = append_argument(
                worker_args, ["--override-engine-args", json.dumps(overrides)]
            )
            worker.extraPodSpec.mainContainer.args = worker_args

        # Exactly one replica of the (re-resolved) decode service.
        decode_worker = cfg.spec.services[
            get_service_name_by_type(cfg, "trtllm", SubComponentType.DECODE)
        ]
        decode_worker.replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Set ``tensor_parallel_size`` to ``tp_size`` via the engine overrides.

        Any existing ``--override-engine-args`` value is parsed, updated and
        re-appended as JSON. Resource setup is delegated to
        ``setup_worker_service_resources(worker, tp_size)``.

        Returns:
            ``model_dump()`` of the edited config.
        """
        cfg = Config.model_validate(config)

        # NOTE(review): the original comment says this assumes convert_config
        # has already run, so the worker carries the decode service name.
        worker = get_worker_service_from_config(
            cfg, backend="trtllm", sub_component_type=component_type
        )
        setup_worker_service_resources(worker, tp_size)

        # break_arguments copes with both joined strings and lists.
        worker_args = break_arguments(
            validate_and_get_worker_args(worker, backend="trtllm")
        )

        overrides, worker_args = parse_override_engine_args(worker_args)
        overrides["tensor_parallel_size"] = tp_size
        worker_args = append_argument(
            worker_args, ["--override-engine-args", json.dumps(overrides)]
        )

        worker.extraPodSpec.mainContainer.args = worker_args
        return cfg.model_dump()

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Unsupported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "TEP (Tensor Expert Parallelism) is not implemented for TrtLLM backend"
        )

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Unsupported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "DEP (Data Expert Parallelism) is not implemented for TrtLLM backend"
        )

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--served-model-name`` in the worker args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        lookup raises ``ValueError``/``KeyError`` or the flag has no value.
        """
        cfg = Config.model_validate(config)
        try:
            worker = get_worker_service_from_config(cfg, backend="trtllm")
            worker_args = validate_and_get_worker_args(worker, backend="trtllm")
        except (ValueError, KeyError):
            logger.warning(
                f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        tokens = break_arguments(worker_args)
        # Walk (flag, next-token) pairs; a trailing flag has no pair to match.
        for flag, value in zip(tokens, tokens[1:]):
            if flag == "--served-model-name":
                return value

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return ``DYNAMO_RUN_DEFAULT_PORT``.

        The frontend args are never inspected (the original comment states the
        TRT-LLM frontend has none); a warning is logged when the ``"Frontend"``
        service or its main container is missing.
        """
        cfg = Config.model_validate(config)
        frontend = cfg.spec.services.get("Frontend")

        has_container = (
            frontend
            and frontend.extraPodSpec
            and frontend.extraPodSpec.mainContainer
        )
        if not has_container:
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )

        return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Read the paged KV cache max-token count from a log file.

        Expected line shape, per the original comment:
        ``[TensorRT-LLM][INFO] [MemUsageChange] Allocated XX GiB for max tokens
        in paged KV cache (XXXXXX).`` The number in parentheses is returned.
        ``attention_dp_size`` is accepted but not used. When nothing matches,
        or reading/parsing raises, a warning is logged and 100000 is returned.
        """
        try:
            with open(dynamo_log_fn, "r") as log_file:
                for entry in log_file:
                    if "Allocated" not in entry:
                        continue
                    if "for max tokens in paged KV cache" not in entry:
                        continue
                    found = re.search(r"paged KV cache \((\d+)\)", entry)
                    if found:
                        max_tokens = int(found.group(1))
                        logger.info(
                            f"Found TRT-LLM KV cache max tokens: {max_tokens}"
                        )
                        return max_tokens
        except Exception as e:
            logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")

        # Nothing usable in the log: fall back to a fixed default.
        logger.warning(
            "Could not find KV cache size in TRT-LLM logs, using default value of 100000"
        )
        return 100000
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Literal
from benchmarks.profiler.utils.config import (
Config,
append_argument,
break_arguments,
get_service_name_by_type,
get_worker_service_from_config,
setup_worker_service_resources,
validate_and_get_worker_args,
)
from benchmarks.profiler.utils.defaults import (
DEFAULT_MODEL_NAME,
DYNAMO_RUN_DEFAULT_PORT,
)
from dynamo.planner.defaults import SubComponentType
# Module logger: INFO and above, echoed through a dedicated stream handler
# that prefixes every record with a timestamp, logger name and level.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
class VllmV1ConfigModifier:
    """Classmethod helpers that rewrite and inspect vLLM deployment configs.

    Each method validates the incoming ``config`` dict through
    ``Config.model_validate``; the mutating methods hand back
    ``cfg.model_dump()`` of the edited model.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Collapse a prefill + decode config into a single-worker config.

        Args:
            config: Deployment config as a plain dict.
            target: ``"prefill"`` keeps the prefill worker (re-registered under
                the decode service name); ``"decode"`` keeps the decode worker.
                Any other value skips the worker rewrite.
            is_moe_model: Must be falsy; MoE is not supported here.

        Returns:
            ``model_dump()`` of the edited config: named ``"vllm-agg"``, any
            ``"Planner"`` service removed, decode service ``replicas`` set to 1.

        Raises:
            NotImplementedError: If ``is_moe_model`` is truthy.
        """
        if is_moe_model:
            raise NotImplementedError(
                "MoE model support is not implemented for VLLM backend"
            )

        cfg = Config.model_validate(config)

        # set metadata name
        cfg.metadata.name = "vllm-agg"

        # disable planner
        if "Planner" in cfg.spec.services:
            del cfg.spec.services["Planner"]

        if target == "prefill":
            # Get service names by inferring from subComponentType first
            prefill_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.PREFILL
            )
            decode_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.DECODE
            )

            # convert prefill worker into decode worker
            cfg.spec.services[decode_service_name] = cfg.spec.services[
                prefill_service_name
            ]
            del cfg.spec.services[prefill_service_name]

            # Set subComponentType for aggregated mode (using decode worker for prefill-only)
            cfg.spec.services[decode_service_name].subComponentType = "decode"

            worker_service = get_worker_service_from_config(
                cfg,
                backend="vllm",
                sub_component_type=SubComponentType.DECODE,
            )
            args = validate_and_get_worker_args(worker_service, backend="vllm")
            args = break_arguments(args)

            # Remove --is-prefill-worker when present. The guard matters:
            # list.remove raises ValueError for a missing element, which would
            # abort the conversion of a worker that simply lacks the flag.
            if "--is-prefill-worker" in args:
                args.remove("--is-prefill-worker")

            # disable prefix caching
            if "--enable-prefix-caching" in args:
                args.remove("--enable-prefix-caching")
            if "--no-enable-prefix-caching" not in args:
                args = append_argument(args, "--no-enable-prefix-caching")

            worker_service.extraPodSpec.mainContainer.args = args

        elif target == "decode":
            # Get service names by inferring from subComponentType first
            prefill_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.PREFILL
            )
            decode_service_name = get_service_name_by_type(
                cfg, "vllm", SubComponentType.DECODE
            )

            # delete prefill worker
            del cfg.spec.services[prefill_service_name]

            # Set subComponentType for aggregated decode-only mode
            cfg.spec.services[decode_service_name].subComponentType = "decode"

            worker_service = get_worker_service_from_config(
                cfg,
                backend="vllm",
                sub_component_type=SubComponentType.DECODE,
            )
            args = validate_and_get_worker_args(worker_service, backend="vllm")
            args = break_arguments(args)

            # enable prefix caching
            if "--enable-prefix-caching" not in args:
                args = append_argument(args, "--enable-prefix-caching")
            if "--no-enable-prefix-caching" in args:
                args.remove("--no-enable-prefix-caching")

            worker_service.extraPodSpec.mainContainer.args = args

        # set num workers to 1, using the inferred decode service name
        final_decode_service_name = get_service_name_by_type(
            cfg, "vllm", SubComponentType.DECODE
        )
        decode_worker_config = cfg.spec.services[final_decode_service_name]
        decode_worker_config.replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Set ``--tensor-parallel-size`` to ``tp_size`` on the selected worker.

        An existing value is overwritten; a missing flag is appended together
        with its value. A flag that is the last token (no value after it) gets
        the value appended instead of raising ``IndexError``. Resource setup is
        delegated to ``setup_worker_service_resources(worker, tp_size)``.

        Returns:
            ``model_dump()`` of the edited config.
        """
        cfg = Config.model_validate(config)
        worker_service = get_worker_service_from_config(
            cfg, backend="vllm", sub_component_type=component_type
        )

        # Set up resources
        setup_worker_service_resources(worker_service, tp_size)

        # Get and validate args
        args = validate_and_get_worker_args(worker_service, backend="vllm")
        args = break_arguments(args)

        flag = "--tensor-parallel-size"
        if flag in args:
            value_idx = args.index(flag) + 1
            if value_idx < len(args):
                args[value_idx] = str(tp_size)
            else:
                # Dangling flag at the end of the list: supply its value.
                args = append_argument(args, str(tp_size))
        else:
            args = append_argument(args, [flag, str(tp_size)])

        worker_service.extraPodSpec.mainContainer.args = args
        return cfg.model_dump()

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Unsupported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "TEP (Tensor Expert Parallelism) is not implemented for VLLM backend"
        )

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ):
        """Unsupported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "DEP (Data Expert Parallelism) is not implemented for VLLM backend"
        )

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--model`` in the worker args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        lookup raises ``ValueError``/``KeyError`` or the flag has no value.
        """
        cfg = Config.model_validate(config)
        try:
            worker_service = get_worker_service_from_config(cfg, backend="vllm")
            args = validate_and_get_worker_args(worker_service, backend="vllm")
        except (ValueError, KeyError):
            logger.warning(
                f"Worker service missing or invalid, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        args = break_arguments(args)
        for i, arg in enumerate(args):
            if arg == "--model" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the integer after ``--http-port`` in the ``"Frontend"`` args.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the
        frontend service, its main container, its args, or a parsable port
        value is missing.
        """
        cfg = Config.model_validate(config)
        frontend_service = cfg.spec.services.get("Frontend")
        if (
            not frontend_service
            or not frontend_service.extraPodSpec
            or not frontend_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = frontend_service.extraPodSpec.mainContainer.args
        if not args:
            logger.warning(
                f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = break_arguments(args)
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            # Flag absent, flag without a value, or a non-integer value.
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Estimate the KV cache size in tokens from a log file.

        Uses the first line containing ``"Maximum concurrency for"``,
        presumably shaped like ``Maximum concurrency for 16,384 tokens per
        request: 10.50x`` (verify against real vLLM logs): the token count
        (thousands separators stripped) is multiplied by the concurrency (last
        character dropped) and truncated to ``int``. ``attention_dp_size`` is
        accepted but not used.

        Returns:
            The computed size, or 0 when no such line exists or when
            opening/reading/parsing fails (the failure is logged as a warning).
        """
        # Bound before the try block so the except handler can always format
        # its message; otherwise a failing open() would leave `line` undefined
        # and the handler itself would raise UnboundLocalError.
        line = ""
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    if "Maximum concurrency for" in line:
                        line = line.strip().split("Maximum concurrency for ")[1]
                        token_count = int(
                            line.split(" tokens per request: ")[0].replace(",", "")
                        )
                        # [:-1] drops the trailing unit character after the number.
                        concurrency = float(line.split(" tokens per request: ")[1][:-1])

                        logger.info(
                            f"Found KV cache info: {token_count} x {concurrency} = {int(token_count * concurrency)}"
                        )
                        return int(token_count * concurrency)
        except Exception as e:
            logger.warning(
                f"Failed to parse KV cache size from line: {line}. Error: {e}"
            )
        return 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment