Unverified Commit ce1a6f49 authored by Hongkuan Zhou, committed by GitHub
Browse files

fix: fallback for agg DGD generation without subComponentType (#6317)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent abacb96e
......@@ -303,19 +303,10 @@ class BaseConfigModifier:
- update_model()
- update_model_from_pvc()
"""
# Update workers (prefill + decode) if present.
for sct in (SubComponentType.PREFILL, SubComponentType.DECODE):
try:
svc_name = get_service_name_by_type(cfg, cls.BACKEND, sct)
except Exception:
continue
if svc_name not in cfg.spec.services:
continue
service = cfg.spec.services[svc_name]
def _patch_service(service: Any) -> None:
if not service.extraPodSpec or not service.extraPodSpec.mainContainer:
continue
return
c = service.extraPodSpec.mainContainer
def _patch(tokens: list[str]) -> list[str]:
......@@ -329,6 +320,26 @@ class BaseConfigModifier:
cls._update_container_args_preserving_shell_form(c, _patch)
# Update workers (prefill + decode) if present.
patched_services: set[str] = set()
for sct in (SubComponentType.PREFILL, SubComponentType.DECODE):
try:
svc_name = get_service_name_by_type(cfg, cls.BACKEND, sct)
except Exception:
continue
if svc_name not in cfg.spec.services:
continue
_patch_service(cfg.spec.services[svc_name])
patched_services.add(svc_name)
# Fallback for agg mode: if no worker was patched via subComponentType
# lookup, patch any non-Frontend/Planner worker service.
if not patched_services:
for name, service in cfg.spec.services.items():
if name not in cls._NON_WORKER_SERVICES:
_patch_service(service)
patched_services.add(name)
if patch_frontend:
cls._update_frontend_cli(cfg, model_name=model_name, model_path=model_path)
......@@ -598,8 +609,20 @@ class BaseConfigModifier:
agg_replicas: int,
agg_gpus: int,
) -> None:
"""Apply CLI args, replicas, and GPU resources to the agg worker service."""
"""Apply CLI args, replicas, and GPU resources to the agg worker service.
In agg mode, the default config template may use a generic worker
service name (e.g. ``TRTLLMWorker``) that does not match the disagg
naming convention (``TRTLLMDecodeWorker``). We first try the standard
DECODE lookup, then fall back to any non-Frontend/Planner service.
"""
svc_name = cls._resolve_service_name(cfg, SubComponentType.DECODE)
if svc_name is None or svc_name not in cfg.spec.services:
# Fallback: find any worker service in the config
for name in cfg.spec.services:
if name not in cls._NON_WORKER_SERVICES:
svc_name = name
break
if svc_name is None or svc_name not in cfg.spec.services:
logger.warning("Could not find worker service for agg mode")
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment