Unverified Commit bd28e442 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

fix: fix config updates on deep copy not propagating back (#3512)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent cdfd70f0
......@@ -901,4 +901,14 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# setup file logging
os.makedirs(args.output_dir, exist_ok=True)
log_file_handler = logging.FileHandler(f"{args.output_dir}/profile_sla.log")
log_file_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
)
log_file_handler.setFormatter(formatter)
logger.addHandler(log_file_handler)
asyncio.run(run_profile(args))
......@@ -224,14 +224,14 @@ def set_multinode_config(worker_service, gpu_count: int, num_gpus_per_node: int)
def get_service_name_by_type(
config: dict, backend: str, sub_component_type: SubComponentType
config: Config, backend: str, sub_component_type: SubComponentType
) -> str:
"""Helper function to get service name by subComponentType.
First tries to find service by subComponentType, then falls back to component name.
Args:
config: Configuration dictionary (with spec.services structure)
config: Configuration object
backend: Backend name (e.g., "sglang", "vllm", "trtllm")
sub_component_type: The type of sub-component to look for (PREFILL or DECODE)
......@@ -239,11 +239,7 @@ def get_service_name_by_type(
The service name
"""
# Check if config has the expected structure
if (
not isinstance(config, dict)
or "spec" not in config
or "services" not in config.get("spec", {})
):
if not config.spec or not config.spec.services:
# Fall back to default name if structure is unexpected
if sub_component_type == SubComponentType.DECODE:
return WORKER_COMPONENT_NAMES[backend].decode_worker_k8s_name
......@@ -251,12 +247,9 @@ def get_service_name_by_type(
return WORKER_COMPONENT_NAMES[backend].prefill_worker_k8s_name
# Look through services to find one with matching subComponentType
services = config["spec"]["services"]
services = config.spec.services
for service_name, service_config in services.items():
if (
isinstance(service_config, dict)
and service_config.get("subComponentType") == sub_component_type.value
):
if service_config.subComponentType == sub_component_type.value:
return service_name
# Fall back to default component names
......@@ -274,7 +267,7 @@ def get_service_name_by_type(
def get_worker_service_from_config(
config: dict,
config: Config,
backend: str = "sglang",
sub_component_type: SubComponentType = SubComponentType.DECODE,
):
......@@ -299,8 +292,7 @@ def get_worker_service_from_config(
service_name = get_service_name_by_type(config, backend, sub_component_type)
# Get the actual service from the config
cfg = Config.model_validate(config)
return cfg.spec.services[service_name]
return config.spec.services[service_name]
def setup_worker_service_resources(
......@@ -445,10 +437,10 @@ class VllmV1ConfigModifier:
if target == "prefill":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.PREFILL
cfg, "vllm", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.DECODE
cfg, "vllm", SubComponentType.DECODE
)
# convert prefill worker into decode worker
......@@ -461,7 +453,7 @@ class VllmV1ConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
cfg,
backend="vllm",
sub_component_type=SubComponentType.DECODE,
)
......@@ -482,10 +474,10 @@ class VllmV1ConfigModifier:
elif target == "decode":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.PREFILL
cfg, "vllm", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.DECODE
cfg, "vllm", SubComponentType.DECODE
)
# delete prefill worker
......@@ -495,7 +487,7 @@ class VllmV1ConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
cfg,
backend="vllm",
sub_component_type=SubComponentType.DECODE,
)
......@@ -513,7 +505,7 @@ class VllmV1ConfigModifier:
# set num workers to 1
# Use the inferred decode service name
final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "vllm", SubComponentType.DECODE
cfg, "vllm", SubComponentType.DECODE
)
decode_worker_config = cfg.spec.services[final_decode_service_name]
decode_worker_config.replicas = 1
......@@ -529,7 +521,7 @@ class VllmV1ConfigModifier:
):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(
config, backend="vllm", sub_component_type=component_type
cfg, backend="vllm", sub_component_type=component_type
)
# Set up resources
......@@ -575,8 +567,9 @@ class VllmV1ConfigModifier:
@classmethod
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
try:
worker_service = get_worker_service_from_config(config, backend="vllm")
worker_service = get_worker_service_from_config(cfg, backend="vllm")
args = validate_and_get_worker_args(worker_service, backend="vllm")
except (ValueError, KeyError):
logger.warning(
......@@ -670,10 +663,10 @@ class SGLangConfigModifier:
if target == "prefill":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.PREFILL
cfg, "sglang", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.DECODE
cfg, "sglang", SubComponentType.DECODE
)
# convert prefill worker into decode worker
......@@ -686,7 +679,7 @@ class SGLangConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
cfg,
backend="sglang",
sub_component_type=SubComponentType.DECODE,
)
......@@ -707,10 +700,10 @@ class SGLangConfigModifier:
elif target == "decode":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.PREFILL
cfg, "sglang", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.DECODE
cfg, "sglang", SubComponentType.DECODE
)
# delete prefill worker
......@@ -720,7 +713,7 @@ class SGLangConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
cfg,
backend="sglang",
sub_component_type=SubComponentType.DECODE,
)
......@@ -751,7 +744,7 @@ class SGLangConfigModifier:
# set num workers to 1
# Use the inferred decode service name
final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "sglang", SubComponentType.DECODE
cfg, "sglang", SubComponentType.DECODE
)
decode_worker_config = cfg.spec.services[final_decode_service_name]
decode_worker_config.replicas = 1
......@@ -767,7 +760,7 @@ class SGLangConfigModifier:
):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(
config, backend="sglang", sub_component_type=component_type
cfg, backend="sglang", sub_component_type=component_type
)
# Set up resources
......@@ -792,7 +785,7 @@ class SGLangConfigModifier:
):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(
config, backend="sglang", sub_component_type=component_type
cfg, backend="sglang", sub_component_type=component_type
)
# Set up resources with multinode configuration
......@@ -827,7 +820,7 @@ class SGLangConfigModifier:
):
cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config(
config, backend="sglang", sub_component_type=component_type
cfg, backend="sglang", sub_component_type=component_type
)
# Set up resources with multinode configuration
......@@ -854,8 +847,9 @@ class SGLangConfigModifier:
@classmethod
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
try:
worker_service = get_worker_service_from_config(config, backend="sglang")
worker_service = get_worker_service_from_config(cfg, backend="sglang")
args = validate_and_get_worker_args(worker_service, backend="sglang")
except (ValueError, KeyError):
logger.warning(
......@@ -946,10 +940,10 @@ class TrtllmConfigModifier:
if target == "prefill":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.PREFILL
cfg, "trtllm", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.DECODE
cfg, "trtllm", SubComponentType.DECODE
)
# Convert to prefill-only aggregated setup
......@@ -963,7 +957,7 @@ class TrtllmConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config(
cfg.model_dump(),
cfg,
backend="trtllm",
sub_component_type=SubComponentType.DECODE,
)
......@@ -1000,10 +994,10 @@ class TrtllmConfigModifier:
elif target == "decode":
# Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.PREFILL
cfg, "trtllm", SubComponentType.PREFILL
)
decode_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.DECODE
cfg, "trtllm", SubComponentType.DECODE
)
# Convert to decode-only aggregated setup
......@@ -1015,7 +1009,7 @@ class TrtllmConfigModifier:
# Decode worker already has the correct name
worker_service = get_worker_service_from_config(
cfg.model_dump(),
cfg,
backend="trtllm",
sub_component_type=SubComponentType.DECODE,
)
......@@ -1048,7 +1042,7 @@ class TrtllmConfigModifier:
# Set num workers to 1
# Use the inferred decode service name
final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "trtllm", SubComponentType.DECODE
cfg, "trtllm", SubComponentType.DECODE
)
worker_config = cfg.spec.services[final_decode_service_name]
worker_config.replicas = 1
......@@ -1067,7 +1061,7 @@ class TrtllmConfigModifier:
# Get the worker service using helper function
# This assumes convert_config has been called, so the service is named decode_worker_k8s_name
worker_service = get_worker_service_from_config(
config, backend="trtllm", sub_component_type=component_type
cfg, backend="trtllm", sub_component_type=component_type
)
# Set up resources
......@@ -1118,8 +1112,9 @@ class TrtllmConfigModifier:
@classmethod
def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
try:
worker_service = get_worker_service_from_config(config, backend="trtllm")
worker_service = get_worker_service_from_config(cfg, backend="trtllm")
args = validate_and_get_worker_args(worker_service, backend="trtllm")
except (ValueError, KeyError):
logger.warning(
......
......@@ -18,6 +18,7 @@ spec:
envFromSecret: hf-token-secret
dynamoNamespace: sglang-disagg
componentType: worker
subComponentType: decode
replicas: 1
resources:
limits:
......@@ -50,6 +51,7 @@ spec:
envFromSecret: hf-token-secret
dynamoNamespace: sglang-disagg
componentType: worker
subComponentType: prefill
replicas: 1
resources:
limits:
......@@ -59,7 +61,7 @@ spec:
image: my-registry/sglang-runtime:my-tag
workingDir: /workspace/components/backends/sglang
command:
- python3
- python3E
- -m
- dynamo.sglang
args:
......
......@@ -18,6 +18,7 @@ spec:
dynamoNamespace: trtllm-disagg
envFromSecret: hf-token-secret
componentType: worker
subComponentType: prefill
replicas: 1
resources:
limits:
......@@ -45,6 +46,7 @@ spec:
dynamoNamespace: trtllm-disagg
envFromSecret: hf-token-secret
componentType: worker
subComponentType: decode
replicas: 1
resources:
limits:
......
......@@ -18,6 +18,7 @@ spec:
dynamoNamespace: vllm-disagg
envFromSecret: hf-token-secret
componentType: worker
subComponentType: decode
replicas: 1
resources:
limits:
......@@ -37,6 +38,7 @@ spec:
dynamoNamespace: vllm-disagg
envFromSecret: hf-token-secret
componentType: worker
subComponentType: prefill
replicas: 1
resources:
limits:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment