Unverified Commit 6945d47e authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

fix(ft-tests): avoid shlex.split corrupting JSON args in ServiceSpec setters (#7012)


Signed-off-by: default avatarTzu-Ling <tzulingk@nvidia.com>
parent 419e936a
...@@ -6,7 +6,6 @@ import logging ...@@ -6,7 +6,6 @@ import logging
import os import os
import re import re
import secrets import secrets
import shlex
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, List, Optional from typing import Any, List, Optional
...@@ -76,6 +75,25 @@ class ServiceSpec: ...@@ -76,6 +75,25 @@ class ServiceSpec:
def envs(self, value: list[dict[str, str]]): def envs(self, value: list[dict[str, str]]):
self._spec["envs"] = value self._spec["envs"] = value
def _get_args(self) -> list[str]:
"""Return the container args list, normalising scalar strings to a list in-place.
Always returns the same list object that is stored in the spec, so
in-place mutations (append / index assignment) are reflected immediately
without an explicit writeback.
"""
try:
container = self._spec["extraPodSpec"]["mainContainer"]
except KeyError:
return []
if "args" not in container:
container["args"] = []
args = container["args"]
if isinstance(args, str):
args = args.split()
container["args"] = args
return args
# ----- Replicas ----- # ----- Replicas -----
@property @property
def replicas(self) -> int: def replicas(self) -> int:
...@@ -88,47 +106,21 @@ class ServiceSpec: ...@@ -88,47 +106,21 @@ class ServiceSpec:
@property @property
def model(self) -> Optional[str]: def model(self) -> Optional[str]:
"""Model being served by this service (checks both --model and --model-path)""" """Model being served by this service (checks both --model and --model-path)"""
try: args = self._get_args()
args_list = self._spec["extraPodSpec"]["mainContainer"]["args"] for i, arg in enumerate(args):
except KeyError: if arg in ["--model", "--model-path"]:
return None if i + 1 < len(args) and not args[i + 1].startswith("-"):
args_str = " ".join(args_list) return args[i + 1]
parts = shlex.split(args_str)
for i, part in enumerate(parts):
if part in ["--model", "--model-path"]:
return parts[i + 1] if i + 1 < len(parts) else None
return None return None
@model.setter @model.setter
def model(self, value: str): def model(self, value: str):
if "extraPodSpec" not in self._spec: args = self._get_args()
return for i, arg in enumerate(args):
if "mainContainer" not in self._spec["extraPodSpec"]: if arg in ["--model", "--model-path"]:
return if i + 1 < len(args) and not args[i + 1].startswith("-"):
args[i + 1] = value
args_list = self._spec["extraPodSpec"]["mainContainer"].get("args", [])
args_str = " ".join(args_list)
parts = shlex.split(args_str)
# Try to update --model first, then --model-path
model_index = None
for i, part in enumerate(parts):
if part in ["--model", "--model-path"]:
model_index = i
break
if model_index is not None:
if model_index + 1 < len(parts):
parts[model_index + 1] = value
else:
return return
else:
return
# Store args as a list of separate strings for proper command-line parsing
# WRONG: [" ".join(parts)] creates ["--model Qwen/Qwen3-0.6B"] (single string)
# RIGHT: parts creates ["--model", "Qwen/Qwen3-0.6B"] (separate strings)
self._spec["extraPodSpec"]["mainContainer"]["args"] = parts
# ----- GPUs ----- # ----- GPUs -----
@property @property
...@@ -149,54 +141,26 @@ class ServiceSpec: ...@@ -149,54 +141,26 @@ class ServiceSpec:
@property @property
def tensor_parallel_size(self) -> int: def tensor_parallel_size(self) -> int:
"""Get tensor parallel size from vLLM arguments""" """Get tensor parallel size from vLLM arguments"""
try: args = self._get_args()
args_list = self._spec["extraPodSpec"]["mainContainer"]["args"] for i, arg in enumerate(args):
except KeyError: if arg == "--tensor-parallel-size":
return 1 # Default tensor parallel size if i + 1 < len(args) and not args[i + 1].startswith("-"):
return int(args[i + 1])
args_str = " ".join(args_list) return 1
parts = shlex.split(args_str)
for i, part in enumerate(parts):
if part == "--tensor-parallel-size":
return int(parts[i + 1]) if i + 1 < len(parts) else 1
return 1 return 1
@tensor_parallel_size.setter @tensor_parallel_size.setter
def tensor_parallel_size(self, value: int): def tensor_parallel_size(self, value: int):
if "extraPodSpec" not in self._spec: args = self._get_args()
return for i, arg in enumerate(args):
if "mainContainer" not in self._spec["extraPodSpec"]: if arg == "--tensor-parallel-size":
return if i + 1 < len(args) and not args[i + 1].startswith("-"):
args[i + 1] = str(value)
args_list = self._spec["extraPodSpec"]["mainContainer"].get("args", []) else:
args_str = " ".join(args_list) args.append(str(value))
parts = shlex.split(args_str) self.gpus = value
return
# Find existing tensor-parallel-size argument args.extend(["--tensor-parallel-size", str(value)])
tp_index = None
for i, part in enumerate(parts):
if part == "--tensor-parallel-size":
tp_index = i
break
if tp_index is not None:
# Update existing value
if tp_index + 1 < len(parts):
parts[tp_index + 1] = str(value)
else:
parts.append(str(value))
else:
# Add new argument
parts.extend(["--tensor-parallel-size", str(value)])
# Store args as a list of separate strings for proper command-line parsing
# When TP > 1, this setter is called and adds --tensor-parallel-size to args.
# WRONG: [" ".join(parts)] would create ["--model Qwen/Qwen3-0.6B --tensor-parallel-size 2"]
# causing argparse to fail with "IndexError: list index out of range"
# RIGHT: parts creates ["--model", "Qwen/Qwen3-0.6B", "--tensor-parallel-size", "2"]
self._spec["extraPodSpec"]["mainContainer"]["args"] = parts
# Auto-adjust GPU count to match tensor parallel size
self.gpus = value self.gpus = value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment