# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Tests for the cancellation example in examples/custom_backend/cancellation
"""

import asyncio
import os
import subprocess

import pytest

# Apply the `pre_merge` marker to every test in this module (pytest reads the
# module-level `pytestmark` variable). Presumably this selects these tests for
# the pre-merge CI run -- confirm against the CI configuration.
pytestmark = pytest.mark.pre_merge


@pytest.fixture(scope="module")
def example_dir():
    """Return the normalized path of the cancellation example directory.

    The directory is located relative to this test file (five levels up, then
    ``examples/custom_backend/cancellation``), so the tests do not depend on
    the current working directory.
    """
    this_file = os.path.abspath(__file__)
    relative_target = "../../../../../examples/custom_backend/cancellation"
    candidate = os.path.join(os.path.dirname(this_file), relative_target)
    return os.path.normpath(candidate)


@pytest.fixture(scope="function")
async def server_process(example_dir):
    """Start the backend ``server.py`` and guarantee it is torn down after the test.

    The server's stdout and stderr are merged into one text pipe so a test can
    read them back (for example with ``stop_process``). Readiness is
    approximated by a fixed one-second sleep.

    Yields:
        The ``subprocess.Popen`` handle of the running backend server.
    """
    server_proc = subprocess.Popen(
        ["python3", "-u", "server.py"],
        cwd=example_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    # Wait for server to start
    await asyncio.sleep(1)

    # Fail fast with the server's own output if it already exited during
    # startup, instead of letting the test fail later on a missing message.
    if server_proc.poll() is not None:
        startup_output, _ = server_proc.communicate()
        pytest.fail(
            f"server.py exited with code {server_proc.returncode} during startup:\n"
            f"{startup_output}"
        )

    yield server_proc

    # Cleanup. terminate() does nothing if the test already stopped and reaped
    # the process (e.g. via stop_process).
    server_proc.terminate()
    try:
        server_proc.wait(timeout=1)
    except subprocess.TimeoutExpired:
        # wait() does not kill on timeout; make sure the server cannot
        # outlive the test if it ignores terminate().
        server_proc.kill()
        server_proc.wait()
    if server_proc.stdout is not None:
        # Release the pipe; closing an already-closed file is a no-op.
        server_proc.stdout.close()


@pytest.fixture(scope="function")
async def middle_server_process(example_dir, server_process):
    """Start ``middle_server.py`` after the backend and guarantee teardown.

    Requesting ``server_process`` makes pytest start the backend server first.
    The middle server's stdout and stderr are merged into one text pipe so a
    test can read them back. Readiness is approximated by a fixed one-second
    sleep.

    Yields:
        The ``subprocess.Popen`` handle of the running middle server.
    """
    middle_proc = subprocess.Popen(
        ["python3", "-u", "middle_server.py"],
        cwd=example_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    # Wait for middle server to start
    await asyncio.sleep(1)

    # Fail fast with the middle server's own output if it already exited
    # during startup, instead of letting the test fail on a missing message.
    if middle_proc.poll() is not None:
        startup_output, _ = middle_proc.communicate()
        pytest.fail(
            f"middle_server.py exited with code {middle_proc.returncode} "
            f"during startup:\n{startup_output}"
        )

    yield middle_proc

    # Cleanup. terminate() does nothing if the test already stopped and reaped
    # the process (e.g. via stop_process).
    middle_proc.terminate()
    try:
        middle_proc.wait(timeout=1)
    except subprocess.TimeoutExpired:
        # wait() does not kill on timeout; make sure the middle server cannot
        # outlive the test if it ignores terminate().
        middle_proc.kill()
        middle_proc.wait()
    if middle_proc.stdout is not None:
        # Release the pipe; closing an already-closed file is a no-op.
        middle_proc.stdout.close()


def run_client(example_dir, use_middle=False, timeout=2):
    """Run ``client.py`` to completion and return its combined output.

    Args:
        example_dir: Directory containing ``client.py``; used as the working
            directory of the client process.
        use_middle: If true, pass ``--middle`` to the client (the tests use
            this to route through the middle server).
        timeout: Seconds to wait for the client to exit.

    Returns:
        The client's stdout and stderr, merged, as text. It is also printed so
        it appears in pytest's captured output.

    Raises:
        subprocess.TimeoutExpired: If the client does not exit within
            ``timeout`` seconds. The client is killed before re-raising.
    """
    cmd = ["python3", "client.py"]
    if use_middle:
        cmd.append("--middle")

    client_proc = subprocess.Popen(
        cmd,
        cwd=example_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    # Wait for client to complete
    try:
        stdout, _ = client_proc.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        # communicate() leaves the child running on timeout; kill it so a hung
        # client does not outlive the test, and keep whatever it printed.
        client_proc.kill()
        stdout, _ = client_proc.communicate()
        print(f"Client stdout (timed out after {timeout}s): {stdout}")
        raise
    print(f"Client stdout: {stdout}")

    return stdout


def stop_process(name, process, timeout=1):
    """Stop a running process, collect what it wrote, and print it.

    Args:
        name: Label printed in front of the captured output.
        process: A ``subprocess.Popen``. Its stdout is expected to be captured
            with ``stdout=subprocess.PIPE``; otherwise ``None`` is returned.
        timeout: Seconds to wait after ``terminate()`` before the process is
            force-killed.

    Returns:
        Everything remaining on the process's stdout pipe.
    """
    process.terminate()
    try:
        stdout, _ = process.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        # The process did not exit after terminate(); communicate() leaves it
        # running on timeout, so kill it and drain the remaining output.
        process.kill()
        stdout, _ = process.communicate()
    print(f"{name}: {stdout}")
    return stdout


@pytest.mark.asyncio
async def test_direct_connection_cancellation(
    temp_file_store, example_dir, server_process
):
    """Client cancels a stream served directly by the backend server.

    The client runs without ``--middle``; afterwards the client output and the
    backend output must each contain their cancellation message.
    """
    print(f"Key-value store dir: {temp_file_store}")

    # Direct connection: the client is started without the --middle flag.
    client_output = run_client(example_dir, use_middle=False)

    # Give the backend a moment to log its cancellation before stopping it.
    await asyncio.sleep(1)
    server_output = stop_process("server_process", server_process)

    checks = (
        ("Client: Cancelling after 3 responses...", client_output, "Client output"),
        ("Server: Cancelled at iteration", server_output, "Server output"),
    )
    for expected, output, label in checks:
        assert expected in output, f"{label}: {output}"


@pytest.mark.asyncio
async def test_middle_server_cancellation(
    temp_file_store, example_dir, server_process, middle_server_process
):
    """Client cancels a stream that is proxied through the middle server.

    The client runs with ``--middle``; afterwards the client, middle server
    and backend outputs must each contain their expected message.
    """
    print(f"Key-value store dir: {temp_file_store}")

    # Proxied connection: the client is started with the --middle flag.
    client_output = run_client(example_dir, use_middle=True)

    # Give the servers a moment to log cancellation before stopping them.
    await asyncio.sleep(1)

    # The backend is stopped before the middle server.
    server_output = stop_process("server_process", server_process)
    middle_output = stop_process("middle_server_process", middle_server_process)

    checks = (
        ("Client: Cancelling after 3 responses...", client_output, "Client output"),
        (
            "Middle server: Forwarding response 2",
            middle_output,
            "Middle server output",
        ),
        ("Server: Cancelled at iteration", server_output, "Server output"),
    )
    for expected, output, label in checks:
        assert expected in output, f"{label}: {output}"