# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Tests for the cancellation example in examples/custom_backend/cancellation
"""

import asyncio
import os
import subprocess
import sys

import pytest

# Markers applied to every test in this module. NOTE(review): gpu_0 presumably
# means "needs no GPU" and pre_merge selects the pre-merge CI suite; the tests
# are integration tests because they spawn the example's server/client scripts.
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
    pytest.mark.integration,
]


@pytest.fixture(scope="module")
def example_dir():
    """Absolute, normalized path of ``examples/custom_backend/cancellation``.

    The example directory is located five levels above the directory that
    contains this test file.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    five_levels_up = os.path.join(*([os.pardir] * 5))
    target = os.path.join(
        here, five_levels_up, "examples", "custom_backend", "cancellation"
    )
    return os.path.normpath(target)


async def _wait_for_startup(proc, name, delay=1):
    """Give ``proc`` ``delay`` seconds to start, failing the test if it already exited.

    A process that dies during startup would otherwise only surface later as a
    confusing client timeout or assertion failure.
    """
    await asyncio.sleep(delay)
    if proc.poll() is not None:
        output, _ = proc.communicate()
        pytest.fail(
            f"{name} exited during startup with code {proc.returncode}:\n{output}"
        )


def _shutdown(proc, timeout=1):
    """Stop ``proc`` if it is still running and release its stdout pipe.

    The tests normally stop these processes themselves via ``stop_process``;
    this is the safety net for a test that fails before reaching that point.
    Unlike a bare ``terminate()`` + ``wait(timeout=...)``, a process that
    ignores the terminate request is killed instead of raising
    ``TimeoutExpired`` and being leaked.
    """
    if proc.poll() is None:
        proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
    if proc.stdout is not None and not proc.stdout.closed:
        proc.stdout.close()


@pytest.fixture(scope="function")
async def server_process(example_dir):
    """Start the backend ``server.py`` and clean it up after the test.

    Yields:
        The ``subprocess.Popen`` handle. stdout and stderr are merged into one
        text pipe so a test can collect the server's output.
    """
    # sys.executable runs the script under the same interpreter (and virtualenv)
    # as pytest, rather than whatever "python3" resolves to on PATH.
    # -u makes the child's output unbuffered so lines reach the pipe immediately.
    server_proc = subprocess.Popen(
        [sys.executable, "-u", "server.py"],
        cwd=example_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    try:
        await _wait_for_startup(server_proc, "server.py")
        yield server_proc
    finally:
        _shutdown(server_proc)


@pytest.fixture(scope="function")
async def middle_server_process(example_dir, server_process):
    """Start ``middle_server.py`` after the backend server and clean it up after the test.

    Requesting ``server_process`` guarantees the backend is started first.

    Yields:
        The ``subprocess.Popen`` handle, with stdout and stderr merged into one
        text pipe.
    """
    middle_proc = subprocess.Popen(
        [sys.executable, "-u", "middle_server.py"],
        cwd=example_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    try:
        await _wait_for_startup(middle_proc, "middle_server.py")
        yield middle_proc
    finally:
        _shutdown(middle_proc)


def run_client(example_dir, use_middle=False, *, timeout=2):
    """Run ``client.py`` from ``example_dir`` and return its combined output.

    Args:
        example_dir: Directory containing ``client.py``; used as the working
            directory.
        use_middle: When true, pass ``--middle`` so the client goes through
            the middle server instead of connecting directly.
        timeout: Seconds to wait for the client to finish.

    Returns:
        The client's stdout and stderr (merged) as text.

    Raises:
        subprocess.TimeoutExpired: If the client does not finish within
            ``timeout``. The client process is killed before this propagates.
    """
    # sys.executable keeps the client on the same interpreter/virtualenv as pytest.
    cmd = [sys.executable, "client.py"]
    if use_middle:
        cmd.append("--middle")

    # subprocess.run() kills the child when the timeout expires; the previous
    # Popen.communicate(timeout=...) left a hung client running after the failure.
    result = subprocess.run(
        cmd,
        cwd=example_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        timeout=timeout,
        check=False,
    )
    print(f"Client stdout: {result.stdout}")
    return result.stdout


def stop_process(name, process, *, timeout=1):
    """Terminate ``process`` and return everything it wrote to its stdout pipe.

    If the process is still alive ``timeout`` seconds after ``terminate()``,
    it is killed and its remaining output is collected. A process that ignores
    the terminate request therefore cannot outlive the test.

    Args:
        name: Label used when printing the captured output.
        process: A ``subprocess.Popen`` created with ``stdout=PIPE``.
        timeout: Seconds to allow for a graceful exit before killing.

    Returns:
        The captured stdout (``str`` when the process was opened in text mode).
    """
    process.terminate()
    try:
        stdout, _ = process.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        # communicate() does not kill the child on timeout. Kill it here and
        # retry; retrying communicate() does not lose already-buffered output.
        process.kill()
        stdout, _ = process.communicate()
        print(f"{name}: did not exit within {timeout}s of terminate(); killed")
    print(f"{name}: {stdout}")
    return stdout


@pytest.mark.asyncio
async def test_direct_connection_cancellation(
    temp_file_store, example_dir, server_process
):
    """Test cancellation with direct client-server connection.

    The client talks to ``server.py`` directly and is expected to cancel after
    3 responses; the server must then report that it saw the cancellation.
    """
    # NOTE(review): temp_file_store is listed before server_process, presumably
    # so the key-value store exists before the server starts. If the server
    # depends on it, an explicit fixture dependency would be more robust than
    # relying on argument order.
    print(f"Key-value store dir: {temp_file_store}")

    # Run the client (direct connection)
    client_output = run_client(example_dir, use_middle=False)

    # Give the server time to print its cancellation message before stopping it
    await asyncio.sleep(1)

    # Capture server output
    server_output = stop_process("server_process", server_process)

    # Assert expected messages
    assert (
        "Client: Cancelling after 3 responses..." in client_output
    ), f"Client output: {client_output}"
    assert (
        "Server: Cancelled at iteration" in server_output
    ), f"Server output: {server_output}"


@pytest.mark.asyncio
async def test_middle_server_cancellation(
    temp_file_store, example_dir, server_process, middle_server_process
):
    """Test cancellation with middle server proxy.

    The client connects through ``middle_server.py``, which forwards to
    ``server.py``. After the client cancels, the middle server must have
    forwarded responses and the backend server must report the cancellation.
    """
    # NOTE(review): temp_file_store is listed before the server fixtures,
    # presumably so the key-value store exists before they start.
    print(f"Key-value store dir: {temp_file_store}")

    # Run the client (through middle server)
    client_output = run_client(example_dir, use_middle=True)

    # Give the servers time to print their cancellation messages
    await asyncio.sleep(1)

    # Capture output from all processes
    server_output = stop_process("server_process", server_process)
    middle_output = stop_process("middle_server_process", middle_server_process)

    # Assert expected messages
    assert (
        "Client: Cancelling after 3 responses..." in client_output
    ), f"Client output: {client_output}"
    assert (
        "Middle server: Forwarding response 2" in middle_output
    ), f"Middle server output: {middle_output}"
    assert (
        "Server: Cancelled at iteration" in server_output
    ), f"Server output: {server_output}"