Unverified commit bd28e442, authored by Hongkuan Zhou, committed by GitHub
Browse files

fix: fix config updates on deep copy not propagating back (#3512)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent cdfd70f0
...@@ -901,4 +901,14 @@ if __name__ == "__main__": ...@@ -901,4 +901,14 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
# setup file logging
os.makedirs(args.output_dir, exist_ok=True)
log_file_handler = logging.FileHandler(f"{args.output_dir}/profile_sla.log")
log_file_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
)
log_file_handler.setFormatter(formatter)
logger.addHandler(log_file_handler)
asyncio.run(run_profile(args)) asyncio.run(run_profile(args))
...@@ -224,14 +224,14 @@ def set_multinode_config(worker_service, gpu_count: int, num_gpus_per_node: int) ...@@ -224,14 +224,14 @@ def set_multinode_config(worker_service, gpu_count: int, num_gpus_per_node: int)
def get_service_name_by_type( def get_service_name_by_type(
config: dict, backend: str, sub_component_type: SubComponentType config: Config, backend: str, sub_component_type: SubComponentType
) -> str: ) -> str:
"""Helper function to get service name by subComponentType. """Helper function to get service name by subComponentType.
First tries to find service by subComponentType, then falls back to component name. First tries to find service by subComponentType, then falls back to component name.
Args: Args:
config: Configuration dictionary (with spec.services structure) config: Configuration object
backend: Backend name (e.g., "sglang", "vllm", "trtllm") backend: Backend name (e.g., "sglang", "vllm", "trtllm")
sub_component_type: The type of sub-component to look for (PREFILL or DECODE) sub_component_type: The type of sub-component to look for (PREFILL or DECODE)
...@@ -239,11 +239,7 @@ def get_service_name_by_type( ...@@ -239,11 +239,7 @@ def get_service_name_by_type(
The service name The service name
""" """
# Check if config has the expected structure # Check if config has the expected structure
if ( if not config.spec or not config.spec.services:
not isinstance(config, dict)
or "spec" not in config
or "services" not in config.get("spec", {})
):
# Fall back to default name if structure is unexpected # Fall back to default name if structure is unexpected
if sub_component_type == SubComponentType.DECODE: if sub_component_type == SubComponentType.DECODE:
return WORKER_COMPONENT_NAMES[backend].decode_worker_k8s_name return WORKER_COMPONENT_NAMES[backend].decode_worker_k8s_name
...@@ -251,12 +247,9 @@ def get_service_name_by_type( ...@@ -251,12 +247,9 @@ def get_service_name_by_type(
return WORKER_COMPONENT_NAMES[backend].prefill_worker_k8s_name return WORKER_COMPONENT_NAMES[backend].prefill_worker_k8s_name
# Look through services to find one with matching subComponentType # Look through services to find one with matching subComponentType
services = config["spec"]["services"] services = config.spec.services
for service_name, service_config in services.items(): for service_name, service_config in services.items():
if ( if service_config.subComponentType == sub_component_type.value:
isinstance(service_config, dict)
and service_config.get("subComponentType") == sub_component_type.value
):
return service_name return service_name
# Fall back to default component names # Fall back to default component names
...@@ -274,7 +267,7 @@ def get_service_name_by_type( ...@@ -274,7 +267,7 @@ def get_service_name_by_type(
def get_worker_service_from_config( def get_worker_service_from_config(
config: dict, config: Config,
backend: str = "sglang", backend: str = "sglang",
sub_component_type: SubComponentType = SubComponentType.DECODE, sub_component_type: SubComponentType = SubComponentType.DECODE,
): ):
...@@ -299,8 +292,7 @@ def get_worker_service_from_config( ...@@ -299,8 +292,7 @@ def get_worker_service_from_config(
service_name = get_service_name_by_type(config, backend, sub_component_type) service_name = get_service_name_by_type(config, backend, sub_component_type)
# Get the actual service from the config # Get the actual service from the config
cfg = Config.model_validate(config) return config.spec.services[service_name]
return cfg.spec.services[service_name]
def setup_worker_service_resources( def setup_worker_service_resources(
...@@ -445,10 +437,10 @@ class VllmV1ConfigModifier: ...@@ -445,10 +437,10 @@ class VllmV1ConfigModifier:
if target == "prefill": if target == "prefill":
# Get service names by inferring from subComponentType first # Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type( prefill_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.PREFILL cfg, "vllm", SubComponentType.PREFILL
) )
decode_service_name = get_service_name_by_type( decode_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.DECODE cfg, "vllm", SubComponentType.DECODE
) )
# convert prefill worker into decode worker # convert prefill worker into decode worker
...@@ -461,7 +453,7 @@ class VllmV1ConfigModifier: ...@@ -461,7 +453,7 @@ class VllmV1ConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode" cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
cfg.model_dump(), cfg,
backend="vllm", backend="vllm",
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
) )
...@@ -482,10 +474,10 @@ class VllmV1ConfigModifier: ...@@ -482,10 +474,10 @@ class VllmV1ConfigModifier:
elif target == "decode": elif target == "decode":
# Get service names by inferring from subComponentType first # Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type( prefill_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.PREFILL cfg, "vllm", SubComponentType.PREFILL
) )
decode_service_name = get_service_name_by_type( decode_service_name = get_service_name_by_type(
config, "vllm", SubComponentType.DECODE cfg, "vllm", SubComponentType.DECODE
) )
# delete prefill worker # delete prefill worker
...@@ -495,7 +487,7 @@ class VllmV1ConfigModifier: ...@@ -495,7 +487,7 @@ class VllmV1ConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode" cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
cfg.model_dump(), cfg,
backend="vllm", backend="vllm",
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
) )
...@@ -513,7 +505,7 @@ class VllmV1ConfigModifier: ...@@ -513,7 +505,7 @@ class VllmV1ConfigModifier:
# set num workers to 1 # set num workers to 1
# Use the inferred decode service name # Use the inferred decode service name
final_decode_service_name = get_service_name_by_type( final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "vllm", SubComponentType.DECODE cfg, "vllm", SubComponentType.DECODE
) )
decode_worker_config = cfg.spec.services[final_decode_service_name] decode_worker_config = cfg.spec.services[final_decode_service_name]
decode_worker_config.replicas = 1 decode_worker_config.replicas = 1
...@@ -529,7 +521,7 @@ class VllmV1ConfigModifier: ...@@ -529,7 +521,7 @@ class VllmV1ConfigModifier:
): ):
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
config, backend="vllm", sub_component_type=component_type cfg, backend="vllm", sub_component_type=component_type
) )
# Set up resources # Set up resources
...@@ -575,8 +567,9 @@ class VllmV1ConfigModifier: ...@@ -575,8 +567,9 @@ class VllmV1ConfigModifier:
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
try: try:
worker_service = get_worker_service_from_config(config, backend="vllm") worker_service = get_worker_service_from_config(cfg, backend="vllm")
args = validate_and_get_worker_args(worker_service, backend="vllm") args = validate_and_get_worker_args(worker_service, backend="vllm")
except (ValueError, KeyError): except (ValueError, KeyError):
logger.warning( logger.warning(
...@@ -670,10 +663,10 @@ class SGLangConfigModifier: ...@@ -670,10 +663,10 @@ class SGLangConfigModifier:
if target == "prefill": if target == "prefill":
# Get service names by inferring from subComponentType first # Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type( prefill_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.PREFILL cfg, "sglang", SubComponentType.PREFILL
) )
decode_service_name = get_service_name_by_type( decode_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.DECODE cfg, "sglang", SubComponentType.DECODE
) )
# convert prefill worker into decode worker # convert prefill worker into decode worker
...@@ -686,7 +679,7 @@ class SGLangConfigModifier: ...@@ -686,7 +679,7 @@ class SGLangConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode" cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
cfg.model_dump(), cfg,
backend="sglang", backend="sglang",
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
) )
...@@ -707,10 +700,10 @@ class SGLangConfigModifier: ...@@ -707,10 +700,10 @@ class SGLangConfigModifier:
elif target == "decode": elif target == "decode":
# Get service names by inferring from subComponentType first # Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type( prefill_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.PREFILL cfg, "sglang", SubComponentType.PREFILL
) )
decode_service_name = get_service_name_by_type( decode_service_name = get_service_name_by_type(
config, "sglang", SubComponentType.DECODE cfg, "sglang", SubComponentType.DECODE
) )
# delete prefill worker # delete prefill worker
...@@ -720,7 +713,7 @@ class SGLangConfigModifier: ...@@ -720,7 +713,7 @@ class SGLangConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode" cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
cfg.model_dump(), cfg,
backend="sglang", backend="sglang",
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
) )
...@@ -751,7 +744,7 @@ class SGLangConfigModifier: ...@@ -751,7 +744,7 @@ class SGLangConfigModifier:
# set num workers to 1 # set num workers to 1
# Use the inferred decode service name # Use the inferred decode service name
final_decode_service_name = get_service_name_by_type( final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "sglang", SubComponentType.DECODE cfg, "sglang", SubComponentType.DECODE
) )
decode_worker_config = cfg.spec.services[final_decode_service_name] decode_worker_config = cfg.spec.services[final_decode_service_name]
decode_worker_config.replicas = 1 decode_worker_config.replicas = 1
...@@ -767,7 +760,7 @@ class SGLangConfigModifier: ...@@ -767,7 +760,7 @@ class SGLangConfigModifier:
): ):
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
config, backend="sglang", sub_component_type=component_type cfg, backend="sglang", sub_component_type=component_type
) )
# Set up resources # Set up resources
...@@ -792,7 +785,7 @@ class SGLangConfigModifier: ...@@ -792,7 +785,7 @@ class SGLangConfigModifier:
): ):
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
config, backend="sglang", sub_component_type=component_type cfg, backend="sglang", sub_component_type=component_type
) )
# Set up resources with multinode configuration # Set up resources with multinode configuration
...@@ -827,7 +820,7 @@ class SGLangConfigModifier: ...@@ -827,7 +820,7 @@ class SGLangConfigModifier:
): ):
cfg = Config.model_validate(config) cfg = Config.model_validate(config)
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
config, backend="sglang", sub_component_type=component_type cfg, backend="sglang", sub_component_type=component_type
) )
# Set up resources with multinode configuration # Set up resources with multinode configuration
...@@ -854,8 +847,9 @@ class SGLangConfigModifier: ...@@ -854,8 +847,9 @@ class SGLangConfigModifier:
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
try: try:
worker_service = get_worker_service_from_config(config, backend="sglang") worker_service = get_worker_service_from_config(cfg, backend="sglang")
args = validate_and_get_worker_args(worker_service, backend="sglang") args = validate_and_get_worker_args(worker_service, backend="sglang")
except (ValueError, KeyError): except (ValueError, KeyError):
logger.warning( logger.warning(
...@@ -946,10 +940,10 @@ class TrtllmConfigModifier: ...@@ -946,10 +940,10 @@ class TrtllmConfigModifier:
if target == "prefill": if target == "prefill":
# Get service names by inferring from subComponentType first # Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type( prefill_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.PREFILL cfg, "trtllm", SubComponentType.PREFILL
) )
decode_service_name = get_service_name_by_type( decode_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.DECODE cfg, "trtllm", SubComponentType.DECODE
) )
# Convert to prefill-only aggregated setup # Convert to prefill-only aggregated setup
...@@ -963,7 +957,7 @@ class TrtllmConfigModifier: ...@@ -963,7 +957,7 @@ class TrtllmConfigModifier:
cfg.spec.services[decode_service_name].subComponentType = "decode" cfg.spec.services[decode_service_name].subComponentType = "decode"
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
cfg.model_dump(), cfg,
backend="trtllm", backend="trtllm",
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
) )
...@@ -1000,10 +994,10 @@ class TrtllmConfigModifier: ...@@ -1000,10 +994,10 @@ class TrtllmConfigModifier:
elif target == "decode": elif target == "decode":
# Get service names by inferring from subComponentType first # Get service names by inferring from subComponentType first
prefill_service_name = get_service_name_by_type( prefill_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.PREFILL cfg, "trtllm", SubComponentType.PREFILL
) )
decode_service_name = get_service_name_by_type( decode_service_name = get_service_name_by_type(
config, "trtllm", SubComponentType.DECODE cfg, "trtllm", SubComponentType.DECODE
) )
# Convert to decode-only aggregated setup # Convert to decode-only aggregated setup
...@@ -1015,7 +1009,7 @@ class TrtllmConfigModifier: ...@@ -1015,7 +1009,7 @@ class TrtllmConfigModifier:
# Decode worker already has the correct name # Decode worker already has the correct name
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
cfg.model_dump(), cfg,
backend="trtllm", backend="trtllm",
sub_component_type=SubComponentType.DECODE, sub_component_type=SubComponentType.DECODE,
) )
...@@ -1048,7 +1042,7 @@ class TrtllmConfigModifier: ...@@ -1048,7 +1042,7 @@ class TrtllmConfigModifier:
# Set num workers to 1 # Set num workers to 1
# Use the inferred decode service name # Use the inferred decode service name
final_decode_service_name = get_service_name_by_type( final_decode_service_name = get_service_name_by_type(
cfg.model_dump(), "trtllm", SubComponentType.DECODE cfg, "trtllm", SubComponentType.DECODE
) )
worker_config = cfg.spec.services[final_decode_service_name] worker_config = cfg.spec.services[final_decode_service_name]
worker_config.replicas = 1 worker_config.replicas = 1
...@@ -1067,7 +1061,7 @@ class TrtllmConfigModifier: ...@@ -1067,7 +1061,7 @@ class TrtllmConfigModifier:
# Get the worker service using helper function # Get the worker service using helper function
# This assumes convert_config has been called, so the service is named decode_worker_k8s_name # This assumes convert_config has been called, so the service is named decode_worker_k8s_name
worker_service = get_worker_service_from_config( worker_service = get_worker_service_from_config(
config, backend="trtllm", sub_component_type=component_type cfg, backend="trtllm", sub_component_type=component_type
) )
# Set up resources # Set up resources
...@@ -1118,8 +1112,9 @@ class TrtllmConfigModifier: ...@@ -1118,8 +1112,9 @@ class TrtllmConfigModifier:
@classmethod @classmethod
def get_model_name(cls, config: dict) -> str: def get_model_name(cls, config: dict) -> str:
cfg = Config.model_validate(config)
try: try:
worker_service = get_worker_service_from_config(config, backend="trtllm") worker_service = get_worker_service_from_config(cfg, backend="trtllm")
args = validate_and_get_worker_args(worker_service, backend="trtllm") args = validate_and_get_worker_args(worker_service, backend="trtllm")
except (ValueError, KeyError): except (ValueError, KeyError):
logger.warning( logger.warning(
......
...@@ -18,6 +18,7 @@ spec: ...@@ -18,6 +18,7 @@ spec:
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
dynamoNamespace: sglang-disagg dynamoNamespace: sglang-disagg
componentType: worker componentType: worker
subComponentType: decode
replicas: 1 replicas: 1
resources: resources:
limits: limits:
...@@ -50,6 +51,7 @@ spec: ...@@ -50,6 +51,7 @@ spec:
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
dynamoNamespace: sglang-disagg dynamoNamespace: sglang-disagg
componentType: worker componentType: worker
subComponentType: prefill
replicas: 1 replicas: 1
resources: resources:
limits: limits:
...@@ -59,7 +61,7 @@ spec: ...@@ -59,7 +61,7 @@ spec:
image: my-registry/sglang-runtime:my-tag image: my-registry/sglang-runtime:my-tag
workingDir: /workspace/components/backends/sglang workingDir: /workspace/components/backends/sglang
command: command:
- python3 - python3
- -m - -m
- dynamo.sglang - dynamo.sglang
args: args:
......
...@@ -18,6 +18,7 @@ spec: ...@@ -18,6 +18,7 @@ spec:
dynamoNamespace: trtllm-disagg dynamoNamespace: trtllm-disagg
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: prefill
replicas: 1 replicas: 1
resources: resources:
limits: limits:
...@@ -45,6 +46,7 @@ spec: ...@@ -45,6 +46,7 @@ spec:
dynamoNamespace: trtllm-disagg dynamoNamespace: trtllm-disagg
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: decode
replicas: 1 replicas: 1
resources: resources:
limits: limits:
......
...@@ -18,6 +18,7 @@ spec: ...@@ -18,6 +18,7 @@ spec:
dynamoNamespace: vllm-disagg dynamoNamespace: vllm-disagg
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: decode
replicas: 1 replicas: 1
resources: resources:
limits: limits:
...@@ -37,6 +38,7 @@ spec: ...@@ -37,6 +38,7 @@ spec:
dynamoNamespace: vllm-disagg dynamoNamespace: vllm-disagg
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
componentType: worker componentType: worker
subComponentType: prefill
replicas: 1 replicas: 1
resources: resources:
limits: limits:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment