Unverified commit 48911230, authored by Dmitry Tokarev and committed by GitHub
Browse files

test(fault_tolerance): replace hand-rolled SSE parser with openai SDK (#8536)


Signed-off-by: Dmitry Tokarev <dtokarev@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
parent a2a2753d
......@@ -90,3 +90,4 @@ repos:
- filelock
- pyyaml
- prometheus_client>=0.23.1
- openai
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import re
import threading
......@@ -9,6 +8,7 @@ import time
import pytest
import requests
from openai import APIError, OpenAI
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.managed_process import (
......@@ -46,34 +46,18 @@ class DynamoFrontendProcess(BaseDynamoFrontendProcess):
)
def _parse_completion_sse_content(line: str) -> str | Exception | None:
"""
Parse an SSE line from the completions API and extract the text content.
Args:
line: Raw SSE line string
def _make_client(frontend_port: int) -> OpenAI:
"""Build an OpenAI client pointed at the test frontend.
Returns:
str: The text content if found
Exception: If error event or parse error
None: If no content (e.g., [DONE] or empty)
max_retries=0 so fault-tolerance tests see the first error instead of
silent retries; api_key is a placeholder since the frontend doesn't auth.
"""
if line.startswith("event: error"):
return Exception(f"SSE error event received: {line}")
if not line.startswith("data: "):
return None # Skip non-data lines
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
return None
try:
chunk = json.loads(data_str)
text = chunk["choices"][0].get("text")
return text # May be None if no text content
except Exception as e:
return Exception(f"Error parsing response chunk: {e}")
return OpenAI(
base_url=f"http://localhost:{frontend_port}/v1",
api_key="not-needed",
max_retries=0,
timeout=240,
)
def start_completion_request(
......@@ -102,14 +86,6 @@ def start_completion_request(
prompt = "Tell me a long long long story about yourself?"
if use_long_prompt:
prompt += " Make sure it is" + " long" * 8000 + "!"
timeout = 240 # Extended timeout for long request
payload = {
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"stream": stream,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending completion request (stream={stream}) with prompt: '{prompt[:50]}...'"
......@@ -118,45 +94,30 @@ def start_completion_request(
response_list.append((None, time.time())) # start timestamp
try:
with requests.post(
f"http://localhost:{frontend_port}/v1/completions",
headers=headers,
json=payload,
timeout=timeout,
stream=stream,
) as response:
logger.info(
f"Received response with status code: {response.status_code}"
client = _make_client(frontend_port)
if stream:
for chunk in client.completions.create(
model=FAULT_TOLERANCE_MODEL_NAME,
prompt=prompt,
stream=True,
):
text = chunk.choices[0].text if chunk.choices else None
# Match the original hand-rolled parser: keep empty strings,
# drop only None. Empty chunks (e.g. the first stream frame)
# still count as a response arrival for delay measurement.
if text is not None:
response_list.append((text, time.time()))
else:
resp = client.completions.create(
model=FAULT_TOLERANCE_MODEL_NAME,
prompt=prompt,
stream=False,
)
if response.status_code != 200:
response_list.append(
(
Exception(
f"Request failed with status {response.status_code}: {response.text}"
),
time.time(),
)
)
return
if stream:
for line in response.iter_lines():
if line:
content = _parse_completion_sse_content(
line.decode("utf-8")
)
if content is not None:
response_list.append((content, time.time()))
else:
try:
content = response.json()["choices"][0]["text"]
response_list.append((content, time.time()))
except Exception as e:
response_list.append(
(Exception(f"Error parsing response: {e}"), time.time())
)
response_list.append((resp.choices[0].text, time.time()))
except Exception as e:
# openai.APIError subclasses cover HTTP non-200, mid-stream
# structured `data: {"error": {...}}` frames, connection failures,
# and timeouts. Non-openai exceptions (network, etc.) also bubble.
logger.error(f"Request failed with error: {e}")
response_list.append((e, time.time()))
......@@ -166,36 +127,6 @@ def start_completion_request(
return request_thread, response_list
def _parse_chat_completion_sse_content(line: str) -> str | Exception | None:
"""
Parse an SSE line and extract the content.
Args:
line: Raw SSE line string
Returns:
str: The content delta if found
Exception: If error event or parse error
None: If no content (e.g., [DONE] or empty delta)
"""
if line.startswith("event: error"):
return Exception(f"SSE error event received: {line}")
if not line.startswith("data: "):
return None # Skip non-data lines
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
return None
try:
chunk = json.loads(data_str)
content = chunk["choices"][0]["delta"].get("content")
return content # May be None if delta has no content
except Exception as e:
return Exception(f"Error parsing response chunk: {e}")
def start_chat_completion_request(
frontend_port: int, stream: bool, use_long_prompt: bool = False
) -> tuple:
......@@ -222,14 +153,6 @@ def start_chat_completion_request(
prompt = "Tell me a long long long story about yourself?"
if use_long_prompt:
prompt += " Make sure it is" + " long" * 8000 + "!"
timeout = 240 # Extended timeout for long request
payload = {
"model": FAULT_TOLERANCE_MODEL_NAME,
"messages": [{"role": "user", "content": prompt}],
"stream": stream,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending chat completion request (stream={stream}) with prompt: '{prompt[:50]}...'"
......@@ -238,45 +161,31 @@ def start_chat_completion_request(
response_list.append((None, time.time())) # start timestamp
try:
with requests.post(
f"http://localhost:{frontend_port}/v1/chat/completions",
headers=headers,
json=payload,
timeout=timeout,
stream=stream,
) as response:
logger.info(
f"Received response with status code: {response.status_code}"
)
if response.status_code != 200:
response_list.append(
(
Exception(
f"Request failed with status {response.status_code}: {response.text}"
),
time.time(),
)
)
return
if stream:
for line in response.iter_lines():
if line:
content = _parse_chat_completion_sse_content(
line.decode("utf-8")
)
if content is not None:
response_list.append((content, time.time()))
else:
try:
content = response.json()["choices"][0]["message"]["content"]
client = _make_client(frontend_port)
if stream:
for chunk in client.chat.completions.create(
model=FAULT_TOLERANCE_MODEL_NAME,
messages=[{"role": "user", "content": prompt}],
stream=True,
):
content = chunk.choices[0].delta.content if chunk.choices else None
# Match the original hand-rolled parser: keep empty strings,
# drop only None. Empty chunks (e.g. the first `role`-only
# stream frame) still count as a response arrival for delay
# measurement.
if content is not None:
response_list.append((content, time.time()))
except Exception as e:
response_list.append(
(Exception(f"Error parsing response: {e}"), time.time())
)
else:
resp = client.chat.completions.create(
model=FAULT_TOLERANCE_MODEL_NAME,
messages=[{"role": "user", "content": prompt}],
stream=False,
)
response_list.append((resp.choices[0].message.content, time.time()))
except Exception as e:
# openai.APIError subclasses cover HTTP non-200, mid-stream
# structured `data: {"error": {...}}` frames, connection failures,
# and timeouts. Non-openai exceptions also bubble for visibility.
logger.error(f"Request failed with error: {e}")
response_list.append((e, time.time()))
......@@ -634,12 +543,12 @@ def run_migration_test(
pytest.fail(
"Request succeeded unexpectedly when migration should have failed"
)
except Exception as e:
error_str = str(e)
assert (
"SSE error event received:" in error_str
or "Request failed with status" in error_str
), f"Unexpected error: {e}"
except APIError as e:
# Expected: openai.APIError covers mid-stream structured error
# frames (DIS-1768 contract) and HTTP non-200 responses. A typed
# check is more robust than matching the exception's stringified
# message against a specific wire-format prefix.
logger.info(f"Got expected APIError: {e}")
try:
verify_migration_occurred(frontend)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment