# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Test configuration and fixtures for Dynamo Python bindings tests.

TWO MODES OF OPERATION:

1. Isolated Mode (ENABLE_ISOLATED_ETCD_AND_NATS=1):
   - Each test gets fresh NATS/ETCD on random ports
   - Requires: pytest-forked (uv pip install pytest-forked); running with -n also requires pytest-xdist
   - Tests using 'runtime' fixture MUST have @pytest.mark.forked
   - Safer, enables parallel execution
   - Run: ENABLE_ISOLATED_ETCD_AND_NATS=1 pytest tests/test_metrics_registry.py -n auto

2. Default Ports Mode (ENABLE_ISOLATED_ETCD_AND_NATS=0, default):
   - All tests share NATS/ETCD on default ports (4222, 2379)
   - No pytest-forked required
   - No @pytest.mark.forked required
   - Faster for sequential runs, but NO parallel execution
   - Run: pytest tests/test_metrics_registry.py

Performance comparison (32-core machine, 13 tests):
    Default ports (ENABLE_ISOLATED_ETCD_AND_NATS=0, default): 4.06s (sequential only)
    Isolated sequential (ENABLE_ISOLATED_ETCD_AND_NATS=1):    8.58s (2.1x slower, but safer)
    Isolated parallel -n 8:   2.82s (1.4x faster than default)
    Isolated parallel -n 16:  2.28s (1.8x faster than default, optimal)
    Isolated parallel -n 32:  2.74s (overhead dominates)

    Recommendation: Default mode for simplicity. Use ENABLE_ISOLATED_ETCD_AND_NATS=1
    with -n 8 to -n 16 when you need parallel execution and maximum test isolation.
"""

import asyncio
import json
import os
import queue
import re
import shutil
import socket
import subprocess
import tempfile
import threading
import time

import pytest

from dynamo.runtime import DistributedRuntime

# Configuration constants
# ENABLE_ISOLATED_ETCD_AND_NATS selects how NATS/ETCD are provisioned for tests:
#   True  -> isolated NATS/ETCD instances on random ports with unique data
#            directories, which makes parallel test execution possible.
#   False -> shared services on the default ports (4222, 2379); run sequentially.
# It is read from the environment variable of the same name: only the exact
# value "1" enables isolation; "0", any other value, or leaving it unset keeps
# it disabled.
ENABLE_ISOLATED_ETCD_AND_NATS = os.getenv("ENABLE_ISOLATED_ETCD_AND_NATS", "0") == "1"

# Timeout constants (seconds)
SERVICE_STARTUP_TIMEOUT = 5
SERVICE_SHUTDOWN_TIMEOUT = 5

# Isolated mode hands every test freshly started NATS/ETCD services, but
# DistributedRuntime is a process-level singleton, so tests have to run in
# separate processes. pytest-forked provides that; if it cannot be imported,
# stop the run via pytest.exit with return code 1 and an explanatory message.
if ENABLE_ISOLATED_ETCD_AND_NATS:
    try:
        import pytest_forked  # noqa: F401
    except ImportError:
        pytest.exit(
            """
pytest-forked is required when ENABLE_ISOLATED_ETCD_AND_NATS=1.
Install it with: uv pip install pytest-forked

This is needed because DistributedRuntime is a process-level singleton
and tests must run in separate processes to avoid 'Worker already initialized' errors.

Alternatively, set ENABLE_ISOLATED_ETCD_AND_NATS=0 to use default ports (slower, sequential only).
""",
            returncode=1,
        )


def get_free_port():
    """
    Return a TCP port number that is currently free on this host.

    The OS chooses the port (bind to port 0). The probing socket is closed
    before returning, even if binding fails, so another process may claim the
    port before the caller uses it. Callers must tolerate that race, e.g. by
    retrying on "address already in use".
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


def wait_for_port(host, port, timeout: float = SERVICE_STARTUP_TIMEOUT) -> bool:
    """
    Wait until something accepts TCP connections on ``host:port``.

    Note: this does not wait for the port to become *free*; it polls until a
    server is listening on it.

    Args:
        host: Host name or address to connect to.
        port: TCP port to probe.
        timeout: Total time budget in seconds. Each connection attempt is
            capped at 1 second and never outlives the remaining budget.

    Returns:
        True as soon as a connection succeeds; False once ``timeout`` has
        elapsed (immediately, without probing, if ``timeout`` <= 0).
    """
    # Monotonic clock: unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        try:
            # create_connection closes its socket on failure; `with` closes it
            # on success, so no socket is ever leaked.
            with socket.create_connection((host, port), timeout=min(1.0, remaining)):
                return True
        except OSError:
            # Refused, timed out, unreachable or unresolvable: not ready yet.
            time.sleep(min(0.1, max(0.0, deadline - time.monotonic())))


def start_nats_and_etcd_default_ports():
    """
    Start NATS and ETCD on default ports (4222, 2379).

    Use this for sequential test execution or when running tests alone.

    A service that already accepts connections on its default port is reused
    instead of being started again; only the missing service(s) are launched.
    In every case NATS_SERVER and ETCD_ENDPOINTS are exported so the runtime
    connects to the default ports.

    Returns:
        (nats_server, etcd, nats_port, etcd_client_port, nats_data_dir,
        etcd_data_dir). ``nats_server`` / ``etcd`` are the Popen handles of
        processes started here, or None for a reused service. Both data dirs
        are always None: no data directory is passed to either binary.
    """
    nats_port = 4222
    etcd_client_port = 2379

    # The endpoints are the same whether services are reused or started below.
    os.environ["NATS_SERVER"] = f"nats://localhost:{nats_port}"
    os.environ["ETCD_ENDPOINTS"] = f"http://localhost:{etcd_client_port}"

    # Check if ports are already in use (reuse them if so)
    # TODO: In the future, error out to ensure proper test isolation
    nats_already_running = wait_for_port("localhost", nats_port, timeout=0.1)
    etcd_already_running = wait_for_port("localhost", etcd_client_port, timeout=0.1)

    if nats_already_running and etcd_already_running:
        print(
            f"Reusing existing NATS on port {nats_port} and ETCD on port {etcd_client_port}"
        )
        return None, None, nats_port, etcd_client_port, None, None

    # Launch only what is missing: a second server on an occupied default port
    # would just fail to bind.
    nats_server = None
    if nats_already_running:
        print(f"Reusing existing NATS on port {nats_port}")
    else:
        print(f"Using NATS on default port {nats_port}")
        nats_server = subprocess.Popen(["nats-server", "-js", "--trace"])

    etcd = None
    if etcd_already_running:
        print(f"Reusing existing ETCD on port {etcd_client_port}")
    else:
        print(f"Using ETCD on default client port {etcd_client_port}")
        try:
            etcd = subprocess.Popen(["etcd"])
        except BaseException:
            # Don't leak the NATS server launched above if etcd cannot start
            # (e.g. the binary is missing); re-raise the original error.
            if nats_server is not None:
                nats_server.kill()
                nats_server.wait()
            raise

    return nats_server, etcd, nats_port, etcd_client_port, None, None


def _parse_etcd_client_port(line):
    """
    Extract the client port from one line of etcd's JSON (zap) log output.

    Returns the port as an int when the line is a JSON object whose "msg"
    announces that etcd is serving client traffic and whose "address" ends in
    ":<port>"; returns None for every other line, including non-JSON output.
    """
    try:
        entry = json.loads(line)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    if not isinstance(entry, dict):
        return None
    msg = str(entry.get("msg", ""))
    # "serving client" also covers "serving client traffic ..." messages.
    if "serving client" not in msg and "serving insecure client" not in msg:
        return None
    match = re.search(r":(\d+)$", str(entry.get("address", "")))
    return int(match.group(1)) if match else None


def _discover_etcd_client_port(etcd, timeout):
    """
    Return the client port that *etcd* reports on its (text-mode) stderr.

    A daemon thread reads etcd's stderr for the whole life of the process: it
    hands the first reported client port to this function and afterwards keeps
    discarding output, so etcd never blocks on a full pipe. Reading on a thread
    also makes *timeout* a hard limit (a blocking readline() could not be).

    Raises:
        RuntimeError: etcd closed stderr (normally: exited) before reporting a
            port, or no port was reported within *timeout* seconds.
    """
    ports = queue.Queue()
    early_output = []  # lines read before the port was found, for error messages

    def _read_stderr():
        reported = False
        try:
            for line in etcd.stderr:
                if reported:
                    continue  # keep draining so etcd never stalls on writes
                early_output.append(line)
                port = _parse_etcd_client_port(line)
                if port is not None:
                    reported = True
                    ports.put(port)
        finally:
            ports.put(None)  # EOF sentinel (or the reader itself failed)

    threading.Thread(target=_read_stderr, name="etcd-stderr-reader", daemon=True).start()

    try:
        port = ports.get(timeout=timeout)
    except queue.Empty:
        raise RuntimeError("Failed to discover ETCD client port from logs") from None
    if port is None:
        # The reader has finished, so early_output is no longer being mutated.
        raise RuntimeError(f"ETCD failed to start: {''.join(early_output)}")
    return port


def _drain_in_background(stream, name):
    """Read and discard *stream* on a daemon thread until EOF.

    Prevents a child whose output is piped but otherwise unread from stalling
    once the OS pipe buffer fills up.
    """

    def _drain():
        for _ in stream:
            pass

    threading.Thread(target=_drain, name=name, daemon=True).start()


def _start_nats_with_retries(nats_data_dir, max_attempts=5):
    """
    Launch nats-server with JetStream on a free port, storing data in *nats_data_dir*.

    get_free_port() cannot reserve the port, so another process may grab it
    first; NATS then exits with "address already in use" and a new port is
    tried, up to *max_attempts* times.

    Returns:
        (process, port) for a NATS server that is still running after its
        initial bind window; its stderr is drained in the background.

    Raises:
        RuntimeError: NATS could not be launched, failed for a reason other
            than a port collision, or collided on every attempt. Cleaning up
            etcd and the data directories is the caller's job.
    """
    for attempt in range(1, max_attempts + 1):
        nats_port = get_free_port()
        print(
            f"Attempt {attempt}: Starting NATS on port {nats_port}, data dir: {nats_data_dir}"
        )
        try:
            nats_server = subprocess.Popen(
                ["nats-server", "-js", "-p", str(nats_port), "-sd", str(nats_data_dir)],
                stderr=subprocess.PIPE,
            )
        except OSError as e:
            # e.g. nats-server is not installed: retrying cannot help.
            raise RuntimeError(f"NATS failed to start: {e}") from e

        # Give NATS a moment to bind to the port; a collision makes it exit at once.
        time.sleep(0.1)
        if nats_server.poll() is None:
            _drain_in_background(nats_server.stderr, name="nats-stderr-drain")
            return nats_server, nats_port

        # NATS exited: collect its stderr (communicate() also closes the pipe).
        _, stderr_bytes = nats_server.communicate()
        stderr = (stderr_bytes or b"").decode(errors="replace")
        if "address already in use" not in stderr.lower():
            raise RuntimeError(f"NATS failed to start: {stderr}")
        print(f"NATS port {nats_port} already in use, retrying...")
        time.sleep(0.1)

    raise RuntimeError(
        f"Failed to start NATS after {max_attempts} attempts: "
        "port already in use on every attempt"
    )


def start_nats_and_etcd_random_ports():
    """
    Start NATS and ETCD with random ports and unique data directories.

    This ensures test isolation by giving each test module (or parallel worker)
    its own NATS/ETCD instances on different ports with separate data directories.
    This allows tests to run in parallel without port or filesystem conflicts.

    Note: etcd uses port 0 (OS-assigned port) to eliminate race conditions; the
    actual client port is read from etcd's JSON log. NATS uses get_free_port()
    with retry logic since it doesn't support port 0.
    Port collision probability per NATS attempt: ~1% (heavy parallel testing), ~0.05% (normal load).
    With 5 retries, probability of all NATS attempts failing: ~1.5e-10 (essentially never).

    The output pipes of both children are drained for their whole lifetime so
    neither can block on a full pipe buffer.

    Returns:
        (nats_server, etcd, nats_port, etcd_client_port, nats_data_dir, etcd_data_dir)
        and exports NATS_SERVER / ETCD_ENDPOINTS for the runtime.

    Raises:
        RuntimeError (or the underlying OSError when etcd cannot be launched)
        on startup failure. Before re-raising, a started etcd is stopped and
        both temporary data directories are removed.
    """
    nats_data_dir = etcd_data_dir = etcd = None
    try:
        nats_data_dir = tempfile.mkdtemp(prefix="nats_data_")
        etcd_data_dir = tempfile.mkdtemp(prefix="etcd_data_")

        # Start etcd first with port 0 (no race condition, no retries needed)
        print(f"Starting ETCD with port 0 (OS-assigned), data dir: {etcd_data_dir}")
        etcd = subprocess.Popen(
            [
                "etcd",
                "--logger",
                "zap",
                "--data-dir",
                str(etcd_data_dir),
                "--listen-client-urls",
                "http://localhost:0",
                "--advertise-client-urls",
                "http://localhost:0",
                "--listen-peer-urls",
                "http://localhost:0",
                "--initial-advertise-peer-urls",
                "http://localhost:0",
                "--initial-cluster",
                "default=http://localhost:0",
            ],
            stderr=subprocess.PIPE,
            # Nothing reads etcd's stdout; discard it so it can never fill a pipe.
            stdout=subprocess.DEVNULL,
            text=True,
            errors="replace",
            bufsize=1,
        )
        etcd_client_port = _discover_etcd_client_port(
            etcd, timeout=SERVICE_STARTUP_TIMEOUT
        )
        print(f"ETCD bound to client port: {etcd_client_port}")

        nats_server, nats_port = _start_nats_with_retries(nats_data_dir)
    except BaseException:
        # Clean up exactly once on any failure (including Ctrl-C), then
        # re-raise the original exception unchanged.
        if etcd is not None:
            etcd.terminate()
            try:
                etcd.wait(timeout=SERVICE_SHUTDOWN_TIMEOUT)
            except subprocess.TimeoutExpired:
                etcd.kill()
                etcd.wait()
        for data_dir in (nats_data_dir, etcd_data_dir):
            if data_dir is not None:
                shutil.rmtree(data_dir, ignore_errors=True)
        raise

    # Set environment variables for the runtime to use
    os.environ["NATS_SERVER"] = f"nats://localhost:{nats_port}"
    os.environ["ETCD_ENDPOINTS"] = f"http://localhost:{etcd_client_port}"

    return nats_server, etcd, nats_port, etcd_client_port, nats_data_dir, etcd_data_dir


def _shutdown_processes(processes):
    """
    Best-effort shutdown of the given ``(name, Popen-or-None)`` pairs.

    Every process is sent terminate() before any is waited on, so they shut
    down in parallel. A process still running after SERVICE_SHUTDOWN_TIMEOUT
    seconds is killed and then reaped so it does not linger as a zombie.
    Failures are printed rather than raised, so teardown never masks the
    outcome of the tests. ``None`` entries are skipped.
    """
    running = [(name, proc) for name, proc in processes if proc is not None]

    for name, proc in running:
        try:
            proc.terminate()
        except Exception as e:
            print(f"Error terminating {name}: {e}")

    for name, proc in running:
        try:
            proc.wait(timeout=SERVICE_SHUTDOWN_TIMEOUT)
        except subprocess.TimeoutExpired:
            print(f"{name} did not terminate gracefully, killing")
            try:
                proc.kill()
                proc.wait()
            except Exception as e:
                print(f"Error killing {name}: {e}")
        except Exception as e:
            print(f"Error waiting for {name}: {e}")


@pytest.fixture(scope="module", autouse=True)
def nats_and_etcd():
    """
    Start NATS and ETCD for testing.

    Scope is "module": the tests of a module share the NATS/ETCD instances
    created by one invocation of this fixture.

    Which starter is used depends on the ENABLE_ISOLATED_ETCD_AND_NATS constant:
    - True: random ports + unique data dirs, enabling parallel execution.
    - False (default; env var unset or not "1"): default ports (4222, 2379) for
      sequential execution. Services already listening there are reused and are
      not stopped at teardown.

    Raises RuntimeError if either service does not accept connections within
    SERVICE_STARTUP_TIMEOUT seconds. Processes started by this fixture are torn
    down and temporary data directories removed even when setup fails.
    """
    start_services = (
        start_nats_and_etcd_random_ports
        if ENABLE_ISOLATED_ETCD_AND_NATS
        else start_nats_and_etcd_default_ports
    )
    (
        nats_server,
        etcd,
        nats_port,
        etcd_client_port,
        nats_data_dir,
        etcd_data_dir,
    ) = start_services()

    try:
        # Wait for services to be ready
        if not wait_for_port("localhost", nats_port, timeout=SERVICE_STARTUP_TIMEOUT):
            raise RuntimeError(f"NATS server failed to start on port {nats_port}")
        if not wait_for_port(
            "localhost", etcd_client_port, timeout=SERVICE_STARTUP_TIMEOUT
        ):
            raise RuntimeError(f"ETCD failed to start on port {etcd_client_port}")

        print(f"NATS ({nats_port}) and ETCD ({etcd_client_port}) services ready")
        yield
    finally:
        # Teardown code - always runs even if setup fails or tests error
        print("Tearing down resources")

        # Only terminate services if we started them (not reusing existing)
        if nats_server is None and etcd is None:
            print("Reused existing services, not stopping them")
        else:
            _shutdown_processes([("NATS", nats_server), ("ETCD", etcd)])

        # ignore_errors=True already makes removal best-effort; a falsy path
        # means no temporary directory was created for that service.
        for data_dir in (nats_data_dir, etcd_data_dir):
            if data_dir:
                shutil.rmtree(data_dir, ignore_errors=True)


@pytest.fixture(scope="function")
def temp_file_store():
    """
    A temporary directory to use as the key-value store, local to one test.

    DYN_FILE_KV points at the directory while the test runs. On exit the
    directory is deleted and DYN_FILE_KV is restored to its previous value, or
    removed if it was previously unset. This keeps later tests from inheriting
    a path to a directory that no longer exists.
    """
    previous = os.environ.get("DYN_FILE_KV")
    with tempfile.TemporaryDirectory() as tmpdir:
        os.environ["DYN_FILE_KV"] = tmpdir
        try:
            yield tmpdir
        finally:
            if previous is None:
                os.environ.pop("DYN_FILE_KV", None)
            else:
                os.environ["DYN_FILE_KV"] = previous


@pytest.fixture(scope="function", autouse=False)
async def runtime(request):
    """
    Provide a DistributedRuntime for one test and shut it down afterwards.

    The runtime is built from the currently running event loop with the
    arguments "file" and "nats".

    IMPORTANT: DistributedRuntime is a process-level singleton. In isolated mode
    (ENABLE_ISOLATED_ETCD_AND_NATS=1) every test requesting this fixture must
    carry `@pytest.mark.forked`, so it runs in its own process. This fixture
    fails the test otherwise. Without forking, creating a second runtime in
    the same process produces "Worker already initialized" errors.
    """
    missing_fork_marker = (
        ENABLE_ISOLATED_ETCD_AND_NATS
        and request.node.get_closest_marker("forked") is None
    )
    if missing_fork_marker:
        pytest.fail(
            f"""
Test '{request.node.name}' uses the 'runtime' fixture but is not marked with @pytest.mark.forked.
This is required when ENABLE_ISOLATED_ETCD_AND_NATS=1.

Add @pytest.mark.forked decorator to run this test in a separate process:
  @pytest.mark.forked
  async def test_my_test(runtime):
      ...

Or set ENABLE_ISOLATED_ETCD_AND_NATS=0 to use default ports (no forking needed).

This is required because DistributedRuntime is a process-level singleton.
"""
        )

    distributed_runtime = DistributedRuntime(
        asyncio.get_running_loop(), "file", "nats"
    )
    yield distributed_runtime
    distributed_runtime.shutdown()