Unverified Commit 9b9f2ce4 authored by Alec's avatar Alec Committed by GitHub
Browse files

fix: pytest robustness and parsing error (#2676)

parent 80279ad3
......@@ -28,7 +28,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
import dynamo.nixl_connect as connect
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
......@@ -56,8 +56,13 @@ CACHE_SIZE_MAXIMUM = 8
class VllmEncodeWorker:
def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs) -> None:
self.downstream_endpoint = args.downstream_endpoint
def __init__(
self,
args: argparse.Namespace,
engine_args: AsyncEngineArgs,
pd_worker_client: Client,
) -> None:
self.pd_worker_client = pd_worker_client
self.engine_args = engine_args
self.model = self.engine_args.model
......@@ -178,16 +183,6 @@ class VllmEncodeWorker:
async def async_init(self, runtime: DistributedRuntime):
logger.info("Startup started.")
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
self.downstream_endpoint
)
self.pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
# Create and initialize a dynamo connector for this worker.
# We'll need this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
......@@ -262,9 +257,22 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
generate_endpoint = component.endpoint(config.endpoint)
handler = VllmEncodeWorker(args, config.engine_args)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint
)
pd_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
handler = VllmEncodeWorker(args, config.engine_args, pd_worker_client)
await handler.async_init(runtime)
logger.info("Waiting for PD Worker Instances ...")
await pd_worker_client.wait_for_instances()
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
try:
......
......@@ -33,7 +33,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import FlexibleArgumentParser
from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
# To import example local module
......@@ -96,9 +96,14 @@ class Processor(ProcessMixIn):
return args, config
def __init__(self, args: argparse.Namespace, engine_args: AsyncEngineArgs):
def __init__(
self,
args: argparse.Namespace,
engine_args: AsyncEngineArgs,
encode_worker_client: Client,
):
self.encode_worker_client = encode_worker_client
self.prompt_template = args.prompt_template
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
......@@ -125,17 +130,6 @@ class Processor(ProcessMixIn):
)
return base_tokenizer
async def async_init(self, runtime: DistributedRuntime):
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
self.downstream_endpoint
)
self.encode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
# Main method to parse the request and send the request to the vllm worker.
async def _generate(
self,
......@@ -300,8 +294,20 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
generate_endpoint = component.endpoint(config.endpoint)
handler = Processor(args, config.engine_args)
await handler.async_init(runtime)
parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
args.downstream_endpoint
)
encode_worker_client = (
await runtime.namespace(parsed_namespace)
.component(parsed_component_name)
.endpoint(parsed_endpoint_name)
.client()
)
handler = Processor(args, config.engine_args, encode_worker_client)
logger.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
# Register the endpoint as entrypoint to a model
await register_llm(
......
......@@ -246,6 +246,24 @@ class Client:
...
def instance_ids(self) -> List[int]:
    """
    Get list of current instance IDs.

    This is a plain (non-async) call, unlike ``wait_for_instances``.

    Returns:
        A list of currently available instance IDs
    """
    ...
async def wait_for_instances(self) -> List[int]:
    """
    Wait for instances to be available for work and return their IDs.

    This is a coroutine and must be awaited.

    Returns:
        A list of instance IDs that are available for work
    """
    ...
async def random(self, request: JsonLike) -> AsyncIterator[JsonLike]:
"""
Pick a random instance of the endpoint and issue the request
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Common base classes and utilities for engine tests (vLLM, TRT-LLM, etc.)"""
from dataclasses import dataclass
from typing import Any, Callable, List
from tests.utils.deployment_graph import Payload
# Common text prompt used across tests
TEXT_PROMPT = "Tell me a short joke about AI."
@dataclass
class EngineConfig:
    """Base configuration for engine test scenarios.

    Field order matters: the dataclass-generated ``__init__`` accepts these
    positionally, and subclasses append their own fields after the defaults.
    """

    # Scenario name.
    name: str
    # NOTE(review): presumably the directory containing the launch script — confirm.
    directory: str
    # Name of the script for this scenario.
    script_name: str
    # NOTE(review): presumably pytest marks applied to the scenario — confirm.
    marks: List[Any]
    # Endpoint paths to exercise.
    endpoints: List[str]
    # Callables that extract a content string from a response.
    response_handlers: List[Callable[[Any], str]]
    # Model identifier; sent as the "model" field of generated payloads.
    model: str
    # NOTE(review): presumably an overall time budget in seconds — confirm.
    timeout: int = 120
    # NOTE(review): presumably a startup delay in seconds — confirm against the process manager.
    delayed_start: int = 0
def create_payload_for_config(config: EngineConfig) -> Payload:
    """Build the default text-only payload for an engine config.

    The chat and completions request bodies both target ``config.model``,
    carry the shared ``TEXT_PROMPT`` and use the same sampling options.
    """
    # Options shared by both request bodies; unpacked last so the key order
    # matches a hand-written literal (model, prompt/messages, then options).
    sampling_options = {
        "max_tokens": 150,
        "temperature": 0.1,
        "stream": False,
    }
    user_message = {"role": "user", "content": TEXT_PROMPT}

    chat_body = {
        "model": config.model,
        "messages": [user_message],
        **sampling_options,
    }
    completions_body = {
        "model": config.model,
        "prompt": TEXT_PROMPT,
        **sampling_options,
    }

    return Payload(
        payload_chat=chat_body,
        payload_completions=completions_body,
        repeat_count=3,
        expected_log=[],
        expected_response=["AI"],
    )
......@@ -5,66 +5,27 @@ import logging
import os
import time
from dataclasses import dataclass
from typing import Any, Callable, List
import pytest
import requests
from tests.serve.common import EngineConfig, create_payload_for_config
from tests.utils.deployment_graph import (
Payload,
chat_completions_response_handler,
completions_response_handler,
)
from tests.utils.managed_process import ManagedProcess
from tests.utils.engine_process import EngineProcess
logger = logging.getLogger(__name__)
text_prompt = "Tell me a short joke about AI."
def create_payload_for_config(config: "TRTLLMConfig") -> Payload:
"""Create a payload using the model from the trtllm config"""
return Payload(
payload_chat={
"model": config.model,
"messages": [
{
"role": "user",
"content": text_prompt,
}
],
"max_tokens": 150,
"temperature": 0.1,
},
payload_completions={
"model": config.model,
"prompt": text_prompt,
"max_tokens": 150,
"temperature": 0.1,
},
repeat_count=1,
expected_log=[],
expected_response=["AI"],
)
# TODO: Unify with vllm/sglang tests to reduce code duplication
@dataclass
class TRTLLMConfig:
class TRTLLMConfig(EngineConfig):
"""Configuration for trtllm test scenarios"""
name: str
directory: str
script_name: str
marks: List[Any]
endpoints: List[str]
response_handlers: List[Callable[[Any], str]]
model: str
timeout: int = 60
delayed_start: int = 0
class TRTLLMProcess(ManagedProcess):
class TRTLLMProcess(EngineProcess):
"""Simple process manager for trtllm shell scripts"""
def __init__(self, config: TRTLLMConfig, request):
......@@ -97,87 +58,6 @@ class TRTLLMProcess(ManagedProcess):
log_dir=request.node.name,
)
def _check_models_api(self, response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False
def _check_url(self, url, timeout=30, sleep=2.0):
"""Override to use a more reasonable retry interval"""
return super()._check_url(url, timeout, sleep)
def check_response(
self, payload, response, response_handler, logger=logging.getLogger()
):
assert response.status_code == 200, "Response Error"
content = response_handler(response)
logger.info(f"Received Content: {content}")
# Check for expected responses
assert content, "Empty response content"
for expected in payload.expected_response:
assert expected in content, f"Expected '{expected}' not found in response"
def wait_for_ready(self, payload, logger=logging.getLogger()):
url = f"http://localhost:{self.port}/{self.config.endpoints[0]}"
start_time = time.time()
retry_delay = 5
elapsed = 0.0
logger.info("Waiting for Deployment Ready")
json_payload = (
payload.payload_chat
if self.config.endpoints[0] == "v1/chat/completions"
else payload.payload_completions
)
while (elapsed := time.time() - start_time) < self.config.timeout:
try:
response = requests.post(
url,
json=json_payload,
timeout=self.config.timeout - elapsed,
)
except (requests.RequestException, requests.Timeout) as e:
logger.warning(f"Retrying due to Request failed: {e}")
time.sleep(retry_delay)
continue
logger.info(f"Response: {response}")
if response.status_code == 500:
error = response.json().get("error", "")
if "no instances" in error:
logger.warning(
f"Retrying due to no instances available for model '{self.config.model}'"
)
time.sleep(retry_delay)
continue
if response.status_code == 404:
error = response.json().get("error", "")
if "Model not found" in error:
logger.warning(
f"Retrying due to model not found for model '{self.config.model}'"
)
time.sleep(retry_delay)
continue
# Process the response
if response.status_code != 200:
pytest.fail(
f"Service returned status code {response.status_code}: {response.text}"
)
else:
break
else:
pytest.fail(
f"Service did not return a successful response within {self.config.timeout} s"
)
self.check_response(payload, response, self.config.response_handlers[0], logger)
logger.info("Deployment Ready")
# trtllm test configurations
trtllm_configs = {
......@@ -192,7 +72,8 @@ trtllm_configs = {
completions_response_handler,
],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
delayed_start=60,
delayed_start=0,
timeout=360,
),
"disaggregated": TRTLLMConfig(
name="disaggregated",
......@@ -205,7 +86,8 @@ trtllm_configs = {
completions_response_handler,
],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
delayed_start=60,
delayed_start=0,
timeout=360,
),
# TODO: These are sanity tests that the kv router examples launch
# and inference without error, but do not do detailed checks on the
......@@ -221,7 +103,8 @@ trtllm_configs = {
completions_response_handler,
],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
delayed_start=60,
delayed_start=0,
timeout=360,
),
"disaggregated_router": TRTLLMConfig(
name="disaggregated_router",
......@@ -234,7 +117,8 @@ trtllm_configs = {
completions_response_handler,
],
model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
delayed_start=60,
delayed_start=0,
timeout=360,
),
}
......@@ -269,8 +153,6 @@ def test_deployment(trtllm_config_test, request, runtime_services):
logger.info(f"Script: {config.script_name}")
with TRTLLMProcess(config, request) as server_process:
server_process.wait_for_ready(payload, logger)
assert len(config.endpoints) == len(config.response_handlers)
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
......@@ -288,11 +170,7 @@ def test_deployment(trtllm_config_test, request, runtime_services):
for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = requests.post(
url,
json=request_body,
timeout=config.timeout - elapsed,
)
server_process.check_response(
payload, response, response_handler, logger
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
......@@ -5,26 +5,26 @@ import logging
import os
import time
from dataclasses import dataclass
from typing import Any, Callable, List, Optional
from typing import List, Optional
import pytest
import requests
from tests.serve.common import EngineConfig
from tests.serve.common import create_payload_for_config as base_create_payload
from tests.utils.deployment_graph import (
Payload,
chat_completions_response_handler,
completions_response_handler,
)
from tests.utils.managed_process import ManagedProcess
from tests.utils.engine_process import EngineProcess
logger = logging.getLogger(__name__)
text_prompt = "Tell me a short joke about AI."
def create_payload_for_config(config: "VLLMConfig") -> Payload:
"""Create a payload using the model from the vLLM config"""
if "multimodal" in config.name:
# Special handling for multimodal models
return Payload(
payload_chat={
"model": config.model,
......@@ -51,47 +51,18 @@ def create_payload_for_config(config: "VLLMConfig") -> Payload:
expected_response=["bus"],
)
else:
return Payload(
payload_chat={
"model": config.model,
"messages": [
{
"role": "user",
"content": text_prompt,
}
],
"max_tokens": 150,
"temperature": 0.1,
},
payload_completions={
"model": config.model,
"prompt": text_prompt,
"max_tokens": 150,
"temperature": 0.1,
},
repeat_count=1,
expected_log=[],
expected_response=["AI"],
)
# Use base implementation for standard text models
return base_create_payload(config)
@dataclass
class VLLMConfig:
class VLLMConfig(EngineConfig):
"""Configuration for vLLM test scenarios"""
name: str
directory: str
script_name: str
marks: List[Any]
endpoints: List[str]
response_handlers: List[Callable[[Any], str]]
model: str
timeout: int = 120
delayed_start: int = 0
args: Optional[List[str]] = None
class VLLMProcess(ManagedProcess):
class VLLMProcess(EngineProcess):
"""Simple process manager for vllm shell scripts"""
def __init__(self, config: VLLMConfig, request):
......@@ -122,102 +93,6 @@ class VLLMProcess(ManagedProcess):
log_dir=request.node.name,
)
def _check_models_api(self, response):
"""Check if models API is working and returns models"""
try:
if response.status_code != 200:
return False
data = response.json()
return data.get("data") and len(data["data"]) > 0
except Exception:
return False
def _check_url(self, url, timeout=30, sleep=2.0):
"""Override to use a more reasonable retry interval"""
return super()._check_url(url, timeout, sleep)
def check_response(
self, payload, response, response_handler, logger=logging.getLogger()
):
assert response.status_code == 200, "Response Error"
content = response_handler(response)
logger.info("Received Content: %s", content)
# Check for expected responses
assert content, "Empty response content"
for expected in payload.expected_response:
assert expected in content, "Expected '%s' not found in response" % expected
def wait_for_ready(self, payload, logger=logging.getLogger()):
url = f"http://localhost:{self.port}/{self.config.endpoints[0]}"
start_time = time.time()
retry_delay = 5
elapsed = 0.0
logger.info("Waiting for Deployment Ready")
json_payload = (
payload.payload_chat
if self.config.endpoints[0] == "v1/chat/completions"
else payload.payload_completions
)
while time.time() - start_time < self.config.timeout:
elapsed = time.time() - start_time
try:
response = requests.post(
url,
json=json_payload,
timeout=self.config.timeout - elapsed,
)
except (requests.RequestException, requests.Timeout) as e:
logger.warning("Retrying due to Request failed: %s", e)
time.sleep(retry_delay)
continue
logger.info("Response%r", response)
if response.status_code == 500:
error = response.json().get("error", "")
if "no instances" in error:
logger.warning("Retrying due to no instances available")
time.sleep(retry_delay)
continue
elif (
"multimodal" in self.config.name
and "Failed to fold chat completions stream" in error
):
logger.warning("Retrying due to endpoint not ready for multimodal")
time.sleep(retry_delay)
continue
if response.status_code == 404:
error = response.json().get("error", "")
if "Model not found" in error:
logger.warning("Retrying due to model not found")
time.sleep(retry_delay)
continue
# Process the response
if response.status_code != 200:
logger.error(
"Service returned status code %s: %s",
response.status_code,
response.text,
)
pytest.fail(
"Service returned status code %s: %s"
% (response.status_code, response.text)
)
else:
break
else:
logger.error(
"Service did not return a successful response within %s s",
self.config.timeout,
)
pytest.fail(
"Service did not return a successful response within %s s"
% self.config.timeout
)
self.check_response(payload, response, self.config.response_handlers[0], logger)
logger.info("Deployment Ready")
# vLLM test configurations
vllm_configs = {
......@@ -232,7 +107,8 @@ vllm_configs = {
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=45,
delayed_start=0,
timeout=360,
),
"agg-router": VLLMConfig(
name="agg-router",
......@@ -245,7 +121,8 @@ vllm_configs = {
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=45,
delayed_start=0,
timeout=360,
),
"disaggregated": VLLMConfig(
name="disaggregated",
......@@ -258,7 +135,8 @@ vllm_configs = {
completions_response_handler,
],
model="Qwen/Qwen3-0.6B",
delayed_start=45,
delayed_start=0,
timeout=360,
),
"deepep": VLLMConfig(
name="deepep",
......@@ -275,7 +153,7 @@ vllm_configs = {
completions_response_handler,
],
model="deepseek-ai/DeepSeek-V2-Lite",
delayed_start=45,
delayed_start=0,
args=[
"--model",
"deepseek-ai/DeepSeek-V2-Lite",
......@@ -286,7 +164,7 @@ vllm_configs = {
"--gpus-per-node",
"2",
],
timeout=500,
timeout=560,
),
"multimodal_agg": VLLMConfig(
name="multimodal_agg",
......@@ -298,8 +176,9 @@ vllm_configs = {
chat_completions_response_handler,
],
model="llava-hf/llava-1.5-7b-hf",
delayed_start=45,
delayed_start=0,
args=["--model", "llava-hf/llava-1.5-7b-hf"],
timeout=360,
),
# TODO: Enable this test case when we have 4 GPUs runners.
# "multimodal_disagg": VLLMConfig(
......@@ -348,8 +227,6 @@ def test_serve_deployment(vllm_config_test, request, runtime_services):
logger.info("Script: %s", config.script_name)
with VLLMProcess(config, request) as server_process:
server_process.wait_for_ready(payload, logger)
for endpoint, response_handler in zip(
config.endpoints, config.response_handlers
):
......@@ -366,11 +243,7 @@ def test_serve_deployment(vllm_config_test, request, runtime_services):
for _ in range(payload.repeat_count):
elapsed = time.time() - start_time
response = requests.post(
url,
json=request_body,
timeout=config.timeout - elapsed,
)
server_process.check_response(
payload, response, response_handler, logger
response = server_process.send_request(
url, payload=request_body, timeout=config.timeout - elapsed
)
server_process.check_response(payload, response, response_handler)
......@@ -40,8 +40,41 @@ def chat_completions_response_handler(response):
assert "choices" in result, "Missing 'choices' in response"
assert len(result["choices"]) > 0, "Empty choices in response"
assert "message" in result["choices"][0], "Missing 'message' in first choice"
assert "content" in result["choices"][0]["message"], "Missing 'content' in message"
return result["choices"][0]["message"]["content"]
message = result["choices"][0]["message"]
# Check for content in all possible fields where parsers might put output:
# 1. content - standard message content
# 2. reasoning_content - for models with reasoning parsers
# 3. refusal - when the model refuses to answer
# 4. tool_calls - for function/tool calling responses
content = message.get("content", "")
reasoning_content = message.get("reasoning_content", "")
refusal = message.get("refusal", "")
# Check for tool calls
tool_calls = message.get("tool_calls", [])
tool_content = ""
if tool_calls:
# Extract content from tool calls
tool_content = ", ".join(
call.get("function", {}).get("arguments", "")
for call in tool_calls
if call.get("function", {}).get("arguments")
)
# Return the first non-empty field in priority order
for field_content in [content, reasoning_content, refusal, tool_content]:
if field_content:
return field_content
# If all fields are empty, provide a detailed error
raise ValueError(
"All possible content fields are empty in message. "
f"Checked: content={repr(content)}, reasoning_content={repr(reasoning_content)}, "
f"refusal={repr(refusal)}, tool_calls={tool_calls}"
)
def completions_response_handler(response):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import shlex
import time
from typing import Any, Callable, Dict

import requests

from tests.utils.managed_process import ManagedProcess
logger = logging.getLogger(__name__)
class EngineResponseError(Exception):
    """Raised when an engine response is invalid or lacks expected content."""
class EngineProcess(ManagedProcess):
    """Base class for LLM engine processes (vLLM, TRT-LLM, etc.).

    Adds HTTP helpers on top of ``ManagedProcess``: a models-API readiness
    predicate, a POST helper with request/response logging, and a response
    validator that raises ``EngineResponseError`` on failure.
    """

    def _check_models_api(self, response):
        """Return True when ``response`` is HTTP 200 with a non-empty ``data`` field.

        Any error while inspecting the response (invalid JSON, unexpected
        shape) is treated as "not ready" and yields False.
        """
        try:
            if response.status_code != 200:
                return False
            models = response.json().get("data")
            # Always hand back a real bool rather than None or the list itself.
            return bool(models) and len(models) > 0
        except Exception:
            return False

    @staticmethod
    def _log_response_body(response: requests.Response) -> None:
        """Log the response headers and a truncated body at DEBUG level."""
        # Skip the JSON parse and re-serialisation when nobody will see it.
        if not logger.isEnabledFor(logging.DEBUG):
            return

        logger.debug("Response headers: %s", dict(response.headers))
        try:
            content_type = response.headers.get("content-type", "")
            if content_type.startswith("application/json"):
                body = json.dumps(response.json(), indent=2)
            else:
                body = response.text
            if len(body) > 1000:
                body = body[:1000] + "... (truncated)"
            logger.debug("Response body: %s", body)
        except Exception as e:
            # Logging is best-effort; never let it fail the request.
            logger.debug("Could not parse response body: %s", e)

    def send_request(
        self, url: str, payload: Dict[str, Any], timeout: float = 30.0
    ) -> requests.Response:
        """
        Send a POST request to the engine with detailed logging.

        Args:
            url: The endpoint URL
            payload: The request payload
            timeout: Request timeout in seconds

        Returns:
            The response object

        Raises:
            requests.RequestException: If the request fails
        """
        # Log the request as a curl command for easy reproduction.
        # shlex.quote keeps the command valid when the payload itself
        # contains single quotes (e.g. "What's in this image?").
        payload_json = json.dumps(payload, indent=2)
        curl_command = (
            f'curl -X POST "{url}" \\\n'
            ' -H "Content-Type: application/json" \\\n'
            f" -d {shlex.quote(payload_json)}"
        )
        logger.info("Sending request (curl equivalent):\n%s", curl_command)

        start_time = time.time()
        try:
            response = requests.post(url, json=payload, timeout=timeout)
        except requests.exceptions.Timeout:
            logger.error("Request timed out after %.2f seconds", timeout)
            raise
        except requests.exceptions.ConnectionError as e:
            logger.error("Connection error: %s", e)
            raise
        except requests.exceptions.RequestException as e:
            logger.error("Request failed: %s", e)
            raise

        elapsed = time.time() - start_time
        logger.info(
            "Received response: status=%d, elapsed=%.2fs",
            response.status_code,
            elapsed,
        )
        self._log_response_body(response)
        return response

    def check_response(
        self,
        payload: Any,
        response: requests.Response,
        response_handler: Callable[[Any], str],
    ) -> None:
        """
        Check if the response is valid and contains expected content.

        Args:
            payload: The original payload (should have expected_response attribute)
            response: The response object
            response_handler: Function to extract content from response

        Raises:
            EngineResponseError: If the response is invalid or missing expected content
        """
        if response.status_code != 200:
            logger.error(
                "Response returned non-200 status code: %d", response.status_code
            )
            error_msg = f"Response returned non-200 status code: {response.status_code}"
            try:
                error_data = response.json()
                # The body may be any JSON value; only a dict carries "error".
                if isinstance(error_data, dict) and "error" in error_data:
                    error_msg += f"\nError details: {error_data['error']}"
                logger.error(
                    "Response error details: %s", json.dumps(error_data, indent=2)
                )
            except Exception:
                logger.error("Response text: %s", response.text[:500])
            raise EngineResponseError(error_msg)

        # Only the handler call is guarded, so an extraction failure is not
        # confused with a problem in the logging below.
        try:
            content = response_handler(response)
        except Exception as e:
            raise EngineResponseError(
                f"Failed to extract content from response: {e}"
            ) from e

        # Check emptiness first: a None result must be reported as empty
        # content, not as a len() failure while building the preview.
        if not content:
            raise EngineResponseError("Response contained empty content")

        logger.info(
            "Extracted content: \n%s",
            content[:200] + "..." if len(content) > 200 else content,
        )

        expected_items = getattr(payload, "expected_response", None)
        if not expected_items:
            return

        missing_expected = [item for item in expected_items if item not in content]
        if missing_expected:
            raise EngineResponseError(
                f"Expected content not found in response. Missing: {missing_expected}"
            )
        logger.info(
            "SUCCESS: All expected content (%s) found in response", expected_items
        )
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import shutil
......@@ -204,6 +205,39 @@ class ManagedProcess:
except (OSError, IOError) as e:
self._logger.warning("Warning: Failed to remove directory %s: %s", path, e)
def _log_tail_on_error(self, lines=20):
    """Log the last ``lines`` lines of the process log file at ERROR level.

    This is a best-effort debugging aid: if the log path is unset or
    missing, if ``lines`` is not positive, or if the file is empty, nothing
    is logged. A failure to read the file is reported as a warning and is
    never raised.

    Args:
        lines: Maximum number of trailing log lines to emit.
    """
    if not (self._log_path and os.path.exists(self._log_path)):
        return
    # Guard explicitly: a slice of [-0:] would select the WHOLE file.
    if lines <= 0:
        return

    try:
        # errors="replace" keeps the tail readable when the process wrote
        # bytes that are not valid text, instead of aborting the dump.
        with open(self._log_path, "r", encoding="utf-8", errors="replace") as f:
            tail = f.readlines()[-lines:]

        if not tail:
            return

        self._logger.error(
            "=== Last %d lines from %s ===", len(tail), self._log_path
        )
        for line in tail:
            # Pass the text as an argument so log content is never treated
            # as a format string.
            self._logger.error("%s", line.rstrip())
        self._logger.error("=== End of log tail ===")
    except Exception as e:
        # Deliberately broad: this runs while another error is already being
        # reported and must not mask it.
        self._logger.warning("Could not read log file: %s", e)
def _check_process_alive(self, context=""):
    """Raise RuntimeError if the main process has already exited.

    Returns silently when there is no process object or when ``poll()``
    reports the process as still running. Otherwise logs the exit code,
    dumps the tail of the log file, and raises.

    Args:
        context: Optional phrase appended to both the log line and the
            exception message.

    Raises:
        RuntimeError: If the main process has exited.
    """
    process = self.proc
    if not process or process.poll() is None:
        return

    exit_code = process.returncode
    # Built once and shared by the log line and the exception message.
    suffix = f" {context}" if context else ""

    self._logger.error(
        "Main server process died with exit code %d%s", exit_code, suffix
    )
    # Try to get last few lines from log for debugging
    self._log_tail_on_error()
    raise RuntimeError(f"Main server process exited with code {exit_code}{suffix}")
def _check_ports(self, timeout):
elapsed = 0.0
for port in self.health_check_ports:
......@@ -216,6 +250,9 @@ class ManagedProcess:
self._logger.info("Checking Port: %s", port)
elapsed = 0.0
while elapsed < timeout:
# Check if the main process is still alive
self._check_process_alive(f"while waiting for port {port}")
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(("localhost", port)) == 0:
self._logger.info("SUCCESS: Check Port: %s", port)
......@@ -231,7 +268,7 @@ class ManagedProcess:
elapsed += self._check_url(url, timeout - elapsed)
return elapsed
def _check_url(self, url, timeout=30, sleep=0.1):
def _check_url(self, url, timeout=30, sleep=1, log_interval=10):
if isinstance(url, tuple):
response_check = url[1]
url = url[0]
......@@ -240,19 +277,71 @@ class ManagedProcess:
start_time = time.time()
self._logger.info("Checking URL %s", url)
elapsed = 0.0
attempt = 0
last_log_time = 0.0
while elapsed < timeout:
self._check_process_alive("while waiting for health check")
attempt += 1
check_failed = False
failure_reason = None
try:
response = requests.get(url, timeout=timeout - elapsed)
if response.status_code == 200:
if response_check is None or response_check(response):
self._logger.info("SUCCESS: Check URL: %s", url)
# Try to format JSON response nicely, otherwise show raw text
try:
response_data = response.json()
response_str = json.dumps(response_data, indent=2)
self._logger.info(
"SUCCESS: Check URL: %s (attempt=%d, elapsed=%.1fs)\nResponse:\n%s",
url,
attempt,
elapsed,
response_str,
)
except (json.JSONDecodeError, Exception):
# If not JSON or any error, show raw text (truncated if too long)
response_text = response.text
if len(response_text) > 500:
response_text = response_text[:500] + "... (truncated)"
self._logger.info(
"SUCCESS: Check URL: %s (attempt=%d, elapsed=%.1fs)\nResponse: %s",
url,
attempt,
elapsed,
response_text,
)
return time.time() - start_time
else:
check_failed = True
failure_reason = "custom check failed"
else:
check_failed = True
failure_reason = f"status code {response.status_code}"
except requests.RequestException as e:
self._logger.warning("URL check failed: %s", e)
check_failed = True
failure_reason = f"request exception: {e}"
# Log progress every log_interval seconds for any failure
if check_failed and elapsed - last_log_time >= log_interval:
self._logger.info(
"Still waiting for URL %s (%s) (attempt=%d, elapsed=%.1fs)",
url,
failure_reason,
attempt,
elapsed,
)
last_log_time = elapsed
time.sleep(sleep)
elapsed = time.time() - start_time
self._logger.error("FAILED: Check URL: %s", url)
self._logger.error(
"FAILED: Check URL: %s (attempts=%d, elapsed=%.1fs)", url, attempt, elapsed
)
raise RuntimeError("FAILED: Check URL: %s" % url)
def _terminate_existing(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment