Unverified commit bd828722, authored by youkaichao, committed by GitHub
Browse files

[ci]try to fix flaky multi-step tests (#11894)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 405eb8e3
...@@ -16,7 +16,6 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps ...@@ -16,7 +16,6 @@ NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
NUM_PROMPTS = [10] NUM_PROMPTS = [10]
DEFAULT_SERVER_ARGS: List[str] = [ DEFAULT_SERVER_ARGS: List[str] = [
"--disable-log-requests",
"--worker-use-ray", "--worker-use-ray",
"--gpu-memory-utilization", "--gpu-memory-utilization",
"0.85", "0.85",
...@@ -110,7 +109,7 @@ async def test_multi_step( ...@@ -110,7 +109,7 @@ async def test_multi_step(
# Spin up client/server & issue completion API requests. # Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240 but was empirically # Default `max_wait_seconds` is 240 but was empirically
# was raised 3x to 720 *just for this test* due to # was raised 5x to 1200 *just for this test* due to
# observed timeouts in GHA CI # observed timeouts in GHA CI
ref_completions = await completions_with_server_args( ref_completions = await completions_with_server_args(
prompts, prompts,
......
...@@ -157,13 +157,19 @@ class RemoteOpenAIServer: ...@@ -157,13 +157,19 @@ class RemoteOpenAIServer:
def url_for(self, *parts: str) -> str: def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts) return self.url_root + "/" + "/".join(parts)
def get_client(self): def get_client(self, **kwargs):
if "timeout" not in kwargs:
kwargs["timeout"] = 600
return openai.OpenAI( return openai.OpenAI(
base_url=self.url_for("v1"), base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY, api_key=self.DUMMY_API_KEY,
max_retries=0,
**kwargs,
) )
def get_async_client(self, **kwargs): def get_async_client(self, **kwargs):
if "timeout" not in kwargs:
kwargs["timeout"] = 600
return openai.AsyncOpenAI(base_url=self.url_for("v1"), return openai.AsyncOpenAI(base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY, api_key=self.DUMMY_API_KEY,
max_retries=0, max_retries=0,
...@@ -780,7 +786,6 @@ async def completions_with_server_args( ...@@ -780,7 +786,6 @@ async def completions_with_server_args(
assert len(max_tokens) == len(prompts) assert len(max_tokens) == len(prompts)
outputs = None outputs = None
max_wait_seconds = 240 * 3 # 240 is default
with RemoteOpenAIServer(model_name, with RemoteOpenAIServer(model_name,
server_cli_args, server_cli_args,
max_wait_seconds=max_wait_seconds) as server: max_wait_seconds=max_wait_seconds) as server:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment