Unverified Commit 31c78df7 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

refactor: PyTest subprocess for cancellation unit tests (#3127)


Signed-off-by: default avatarmichaelfeil <me@michaelfeil.eu>
Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
Co-authored-by: default avatarmichaelfeil <me@michaelfeil.eu>
parent 007b9d60
......@@ -22,6 +22,7 @@ pytest-asyncio
pytest-benchmark
pytest-codeblocks
pytest-cov
pytest-forked
pytest-md-report
pytest-mypy
pytest-timeout
......
......@@ -14,12 +14,12 @@
# limitations under the License.
import asyncio
import random
import string
import pytest
from dynamo._core import DistributedRuntime
from dynamo.runtime import Context
pytestmark = pytest.mark.pre_merge
class MockServer:
......@@ -124,28 +124,10 @@ class MockServer:
raise asyncio.CancelledError
def random_string(length=10):
"""Generate a random string for namespace isolation"""
# Start with a letter to satisfy Prometheus naming requirements
first_char = random.choice(string.ascii_lowercase)
remaining_chars = string.ascii_lowercase + string.digits
rest = "".join(random.choices(remaining_chars, k=length - 1))
return first_char + rest
@pytest.fixture
async def runtime():
"""Create a DistributedRuntime for testing"""
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True)
yield runtime
runtime.shutdown()
@pytest.fixture
def namespace():
"""Generate a random namespace for test isolation"""
return random_string()
"""Namespace for this test file"""
return "cancellation_unit_test"
@pytest.fixture
......@@ -158,7 +140,6 @@ async def server(runtime, namespace):
"""Initialize the test server component and serve the generate endpoint"""
component = runtime.namespace(namespace).component("backend")
await component.create_service()
endpoint = component.endpoint("generate")
print("Started test server instance")
......@@ -191,3 +172,120 @@ async def client(runtime, namespace):
await client.wait_for_instances()
return client
@pytest.mark.forked
@pytest.mark.asyncio
async def test_client_context_cancel(server, client):
_, handler = server
context = Context()
stream = await client.generate("_generate_until_context_cancelled", context=context)
iteration_count = 0
async for annotated in stream:
number = annotated.data()
print(f"Received iteration: {number}")
# Verify received valid number
assert number == iteration_count
# Break after receiving 2 responses
if iteration_count >= 2:
print("Cancelling after 2 responses...")
context.stop_generating()
break
iteration_count += 1
# Give server a moment to process the cancellation
await asyncio.sleep(0.2)
# Verify server detected the cancellation
assert handler.context_is_stopped
assert not handler.context_is_killed
# TODO: Test with _generate_until_asyncio_cancelled server handler
@pytest.mark.forked
@pytest.mark.asyncio
async def test_client_loop_break(server, client):
_, handler = server
stream = await client.generate("_generate_until_context_cancelled")
iteration_count = 0
async for annotated in stream:
number = annotated.data()
print(f"Received iteration: {number}")
# Verify received valid number
assert number == iteration_count
# Break after receiving 2 responses
if iteration_count >= 2:
print("Cancelling after 2 responses...")
break
iteration_count += 1
# Give server a moment to process the cancellation
await asyncio.sleep(0.2)
# TODO: Implicit cancellation is not yet implemented, so the server context will not
# show any cancellation.
assert not handler.context_is_stopped
assert not handler.context_is_killed
# TODO: Test with _generate_until_asyncio_cancelled server handler
@pytest.mark.forked
@pytest.mark.asyncio
async def test_server_context_cancel(server, client):
_, handler = server
stream = await client.generate("_generate_and_cancel_context")
iteration_count = 0
try:
async for annotated in stream:
number = annotated.data()
print(f"Received iteration: {number}")
assert number == iteration_count
iteration_count += 1
assert False, "Stream completed without cancellation"
except ValueError as e:
# Verify the expected cancellation exception is received
# TODO: Should this be a asyncio.CancelledError?
assert str(e) == "Stream ended before generation completed"
# Verify server context cancellation status
assert handler.context_is_stopped
assert not handler.context_is_killed
@pytest.mark.forked
@pytest.mark.asyncio
async def test_server_raise_cancelled(server, client):
_, handler = server
stream = await client.generate("_generate_and_raise_cancelled")
iteration_count = 0
try:
async for annotated in stream:
number = annotated.data()
print(f"Received iteration: {number}")
assert number == iteration_count
iteration_count += 1
assert False, "Stream completed without cancellation"
except ValueError as e:
# Verify the expected cancellation exception is received
# TODO: Should this be a asyncio.CancelledError?
assert (
str(e)
== "a python exception was caught while processing the async generator: CancelledError: "
)
# Verify server context cancellation status
# TODO: Server to gracefully stop the stream?
assert not handler.context_is_stopped
assert not handler.context_is_killed
......@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import subprocess
from time import sleep
import pytest
from dynamo.runtime import DistributedRuntime
@pytest.fixture(scope="module", autouse=True)
def nats_and_etcd():
......@@ -35,3 +38,16 @@ def nats_and_etcd():
nats_server.wait()
etcd.terminate()
etcd.wait()
@pytest.fixture(scope="function", autouse=False)
async def runtime():
"""
Create a DistributedRuntime for testing.
DistributedRuntime has singleton requirements, so tests using this fixture should be
marked with `@pytest.mark.forked` to run in a separate process for isolation.
"""
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop, True)
yield runtime
runtime.shutdown()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import pytest
pytestmark = pytest.mark.pre_merge
def _run_test_in_subprocess(test_name: str):
"""Helper function to run a test file in a separate process"""
test_file = os.path.join(os.path.dirname(__file__), f"{test_name}.py")
result = subprocess.run(
["pytest", test_file, "-v"],
capture_output=True,
text=True,
cwd=os.path.dirname(__file__),
)
print("STDOUT:", result.stdout)
print("STDERR:", result.stderr)
print("Return code:", result.returncode)
assert (
result.returncode == 0
), f"Test {test_name} failed with return code {result.returncode}"
def test_client_context_cancel():
_run_test_in_subprocess("test_client_context_cancel")
def test_client_loop_break():
_run_test_in_subprocess("test_client_loop_break")
def test_server_context_cancel():
_run_test_in_subprocess("test_server_context_cancel")
def test_server_raise_cancelled():
_run_test_in_subprocess("test_server_raise_cancelled")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import pytest
from dynamo._core import Context
@pytest.mark.asyncio
async def test_client_context_cancel(server, client):
_, handler = server
context = Context()
stream = await client.generate("_generate_until_context_cancelled", context=context)
iteration_count = 0
async for annotated in stream:
number = annotated.data()
print(f"Received iteration: {number}")
# Verify received valid number
assert number == iteration_count
# Break after receiving 2 responses
if iteration_count >= 2:
print("Cancelling after 2 responses...")
context.stop_generating()
iteration_count += 1
# Verify we received exactly 3 responses (0, 1, 2)
assert iteration_count == 3
# Give server a moment to process the cancellation
await asyncio.sleep(0.2)
# Verify server detected the cancellation
assert handler.context_is_stopped
assert not handler.context_is_killed
# TODO: Test with _generate_until_asyncio_cancelled server handler
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import pytest
@pytest.mark.asyncio
async def test_client_loop_break(server, client):
_, handler = server
stream = await client.generate("_generate_until_context_cancelled")
iteration_count = 0
async for annotated in stream:
number = annotated.data()
print(f"Received iteration: {number}")
# Verify received valid number
assert number == iteration_count
# Break after receiving 2 responses
if iteration_count >= 2:
print("Cancelling after 2 responses...")
break
iteration_count += 1
# Give server a moment to process the cancellation
await asyncio.sleep(0.2)
# TODO: Implicit cancellation is not yet implemented, so the server context will not
# show any cancellation.
assert not handler.context_is_stopped
assert not handler.context_is_killed
# TODO: Test with _generate_until_asyncio_cancelled server handler
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
@pytest.mark.asyncio
async def test_server_context_cancel(server, client):
_, handler = server
stream = await client.generate("_generate_and_cancel_context")
iteration_count = 0
try:
async for annotated in stream:
number = annotated.data()
print(f"Received iteration: {number}")
assert number == iteration_count
iteration_count += 1
assert False, "Stream completed without cancellation"
except ValueError as e:
# Verify the expected cancellation exception is received
# TODO: Should this be a asyncio.CancelledError?
assert str(e) == "Stream ended before generation completed"
# Verify server context cancellation status
assert handler.context_is_stopped
assert not handler.context_is_killed
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
@pytest.mark.asyncio
async def test_server_raise_cancelled(server, client):
_, handler = server
stream = await client.generate("_generate_and_raise_cancelled")
iteration_count = 0
try:
async for annotated in stream:
number = annotated.data()
print(f"Received iteration: {number}")
assert number == iteration_count
iteration_count += 1
assert False, "Stream completed without cancellation"
except ValueError as e:
# Verify the expected cancellation exception is received
# TODO: Should this be a asyncio.CancelledError?
assert (
str(e)
== "a python exception was caught while processing the async generator: CancelledError: "
)
# Verify server context cancellation status
# TODO: Server to gracefully stop the stream?
assert not handler.context_is_stopped
assert not handler.context_is_killed
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment