Unverified commit 6945d47e, authored by Tzu-Ling Kan and committed by GitHub
Browse files

fix(ft-tests): avoid shlex.split corrupting JSON args in ServiceSpec setters (#7012)


Signed-off-by: Tzu-Ling <tzulingk@nvidia.com>
parent 419e936a
......@@ -6,7 +6,6 @@ import logging
import os
import re
import secrets
import shlex
import time
from dataclasses import dataclass, field
from typing import Any, List, Optional
......@@ -76,6 +75,25 @@ class ServiceSpec:
def envs(self, value: list[dict[str, str]]):
self._spec["envs"] = value
def _get_args(self) -> list[str]:
"""Return the container args list, normalising scalar strings to a list in-place.
Always returns the same list object that is stored in the spec, so
in-place mutations (append / index assignment) are reflected immediately
without an explicit writeback.
"""
try:
container = self._spec["extraPodSpec"]["mainContainer"]
except KeyError:
return []
if "args" not in container:
container["args"] = []
args = container["args"]
if isinstance(args, str):
args = args.split()
container["args"] = args
return args
# ----- Replicas -----
@property
def replicas(self) -> int:
......@@ -88,48 +106,22 @@ class ServiceSpec:
@property
def model(self) -> Optional[str]:
    """Model served by this service, taken from ``--model`` or ``--model-path``.

    Scans the token list returned by ``self._get_args()`` and returns the
    token that follows the first such flag, provided that token does not
    start with ``-`` (i.e. it is a value rather than another flag).
    Returns ``None`` when no flag/value pair is found.

    The tokens are compared as-is; they are never joined and re-split, so
    values containing quotes or spaces (such as JSON) cannot affect the scan.
    """
    args = self._get_args()
    # Pair every token with its successor; a trailing flag has no pair and
    # is therefore ignored.
    for flag, candidate in zip(args, args[1:]):
        if flag in ("--model", "--model-path") and not candidate.startswith("-"):
            return candidate
    return None
@model.setter
def model(self, value: str):
    """Replace the served model name inside the container args, in place.

    Updates the token following the first ``--model`` / ``--model-path``
    flag that is followed by a value (a token not starting with ``-``),
    which is the same token the ``model`` getter reads.  Does nothing when
    no such flag/value pair exists.

    The list returned by ``self._get_args()`` is mutated directly, so every
    other token keeps its exact content and args stay a list of separate
    strings (``["--model", "name"]``, never ``["--model name"]``).
    """
    args = self._get_args()
    # Iterate over a copy without the last token: a trailing flag has no
    # value slot to overwrite.
    for i, arg in enumerate(args[:-1]):
        if arg in ("--model", "--model-path") and not args[i + 1].startswith("-"):
            args[i + 1] = value
            return
# ----- GPUs -----
@property
def gpus(self) -> int:
......@@ -149,54 +141,26 @@ class ServiceSpec:
@property
def tensor_parallel_size(self) -> int:
    """Tensor parallel size read from the ``--tensor-parallel-size`` argument.

    Looks at the first ``--tensor-parallel-size`` token in the list returned
    by ``self._get_args()``.  If the next token exists and does not start
    with ``-``, it is converted with ``int`` and returned; otherwise, or
    when the flag is absent, the default of ``1`` is returned.

    Raises:
        ValueError: if the value token is not a valid integer.
    """
    args = self._get_args()
    for i, arg in enumerate(args):
        if arg != "--tensor-parallel-size":
            continue
        if i + 1 < len(args):
            # str() lets a non-string token (e.g. a numeric scalar) through
            # the prefix check instead of failing on a missing method.
            candidate = str(args[i + 1])
            if not candidate.startswith("-"):
                return int(candidate)
        # Flag present without a usable value: fall back to the default.
        return 1
    return 1
@tensor_parallel_size.setter
def tensor_parallel_size(self, value: int):
    """Set ``--tensor-parallel-size`` in the container args and sync ``gpus``.

    The list returned by ``self._get_args()`` is edited in place, so args
    stay a list of separate tokens
    (``[..., "--tensor-parallel-size", "2"]``) and all other tokens keep
    their exact content:

    * flag present and followed by a value: the value is overwritten;
    * flag present but last, or followed by another flag: the value is
      inserted directly after the flag;
    * flag absent: the flag and value are appended.

    Afterwards ``self.gpus`` is assigned ``value`` so the GPU count matches
    the tensor parallel size.
    """
    args = self._get_args()
    token = str(value)
    try:
        flag_pos = args.index("--tensor-parallel-size")
    except ValueError:
        args.extend(["--tensor-parallel-size", token])
    else:
        value_pos = flag_pos + 1
        if value_pos < len(args) and not str(args[value_pos]).startswith("-"):
            args[value_pos] = token
        else:
            # Insert right after the flag; appending at the end would attach
            # the value to whatever flag happens to come last.
            args.insert(value_pos, token)
    self.gpus = value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment