# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import subprocess
import sys

import pytest
import torch


def _stop_process(proc, grace=5.0):
    """Terminate ``proc`` if it is still running, killing it after ``grace`` s."""
    if proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=grace)
    except subprocess.TimeoutExpired:
        # terminate() was ignored; force it so no child outlives the test.
        proc.kill()
        proc.wait()


def run_python_script(script_name, timeout):
    """Run a ``kv_transfer`` script as two concurrent processes (RANK 0 and 1).

    Both processes execute the same script and differ only in the ``RANK``
    environment variable. The current pytest test fails if either process
    exits with a non-zero status, if waiting on either process exceeds
    ``timeout`` seconds, or if launching/waiting raises an exception. Any
    process still running when this function exits is terminated.

    Args:
        script_name: Script file name, relative to the ``kv_transfer/``
            directory (resolved against the current working directory).
        timeout: Seconds to wait for *each* process. The waits are
            sequential, so the worst case is roughly ``2 * timeout``.
    """
    script_name = f"kv_transfer/{script_name}"
    processes = []

    try:
        # Start both ranks before waiting on either, so they run concurrently.
        for rank in (0, 1):
            processes.append(
                subprocess.Popen(
                    [sys.executable, script_name],
                    # Inherit the parent environment (PATH, PYTHONPATH,
                    # CUDA/library paths, ...) and only override RANK; a bare
                    # {"RANK": ...} dict would wipe everything else.
                    env={**os.environ, "RANK": str(rank)},
                    # None = inherit this process's stdout/stderr fds. Unlike
                    # passing sys.stdout, this does not require sys.stdout to
                    # expose a real fileno() (e.g. under --capture=sys).
                    stdout=None,
                    stderr=None,
                )
            )

        # Wait for both processes to complete, each with its own timeout.
        for proc in processes:
            proc.wait(timeout=timeout)

        # Check the return status of both processes.
        for rank, proc in enumerate(processes):
            if proc.returncode != 0:
                pytest.fail(
                    f"Test {script_name} failed for RANK={rank}, {proc.returncode}"
                )

    except subprocess.TimeoutExpired:
        # Children are stopped in the ``finally`` below.
        pytest.fail(f"Test {script_name} timed out")
    except Exception as e:
        # pytest.fail raises an OutcomeException (a BaseException), so the
        # failures above are not re-wrapped here.
        pytest.fail(f"Test {script_name} failed with error: {e}")
    finally:
        # Never leave a peer running, e.g. after a timeout or when launching
        # the second rank raised after the first had already started.
        for proc in processes:
            _stop_process(proc)


# Each case is (script file under kv_transfer/, timeout in seconds passed to
# run_python_script).
@pytest.mark.parametrize(
    "script_name,timeout",
    [
        ("test_lookup_buffer.py", 60),  # 60-second timeout
        ("test_send_recv.py", 120),  # 120-second timeout
    ],
)
def test_run_python_script(script_name, timeout):
    """Run one two-rank kv_transfer script, skipping on hosts with <2 GPUs."""
    # The script is launched as two ranks; skip unless at least two CUDA
    # devices are visible.
    if torch.cuda.device_count() < 2:
        pytest.skip(f"Skipping test {script_name} because <2 GPUs are available")

    run_python_script(script_name, timeout)