"vscode:/vscode.git/clone" did not exist on "883131544faf78f31f85a0350f74ea913ee6ef9c"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ctypes
from concurrent.futures import ThreadPoolExecutor

import pytest
import torch

from vllm.platforms import current_platform


def check_cuda_context():
    """Check CUDA driver context status for the calling thread.

    Loads the CUDA driver library through ``ctypes`` and calls
    ``cuCtxGetDevice`` to find out whether a context is current.

    Returns:
        ``(True, device_id)`` when ``cuCtxGetDevice`` returns 0 (success),
        otherwise ``(False, None)``. ``(False, None)`` is also returned when
        the driver library cannot be loaded or the call itself raises, so
        this probe never propagates an exception.
    """
    cuda = None
    # The unversioned "libcuda.so" name may be absent on runtime-only driver
    # installs, so fall back to the versioned soname before giving up.
    for lib_name in ("libcuda.so", "libcuda.so.1"):
        try:
            cuda = ctypes.CDLL(lib_name)
        except OSError:
            continue
        break
    if cuda is None:
        return False, None

    try:
        device = ctypes.c_int()
        result = cuda.cuCtxGetDevice(ctypes.byref(device))
    except Exception:
        # Best-effort probe: a missing symbol or ctypes failure is reported
        # the same way as "no context".
        return False, None
    return (True, device.value) if result == 0 else (False, None)


def run_cuda_test_in_thread(device_input, expected_device_id):
    """Run CUDA context test in separate thread for isolation.

    Args:
        device_input: Device specifier handed to
            ``current_platform.set_device``.
        expected_device_id: Device id that ``check_cuda_context`` is expected
            to report once the device has been set.

    Returns:
        ``(True, "Success")`` when every check passes, otherwise
        ``(False, message)`` describing the first failed check or the
        exception that was raised. Exceptions are never propagated.
    """
    try:
        # New thread should have no CUDA context initially
        valid_before, device_before = check_cuda_context()
        if valid_before:
            return (
                False,
                "CUDA context should not exist in new thread, "
                f"got device {device_before}",
            )

        # Test setting CUDA context
        current_platform.set_device(device_input)

        # Verify context is created correctly
        valid_after, device_id = check_cuda_context()
        if not valid_after:
            return False, "CUDA context should be valid after set_cuda_context"
        if device_id != expected_device_id:
            return False, f"Expected device {expected_device_id}, got {device_id}"

        return True, "Success"
    except Exception as e:
        # Include the exception type: str(e) alone is empty for exceptions
        # raised without arguments, which would leave no diagnostic.
        return False, f"Exception in thread: {type(e).__name__}: {e}"


@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
class TestSetCudaContext:
    """Test suite for CUDA context handling in ``current_platform.set_device``.

    The ``skipif`` mark is applied once at class level, so every test in the
    class is skipped when the current platform is not CUDA.
    """

    @pytest.mark.parametrize(
        argnames="device_input,expected_device_id",
        argvalues=[
            (0, 0),
            (torch.device("cuda:0"), 0),
            ("cuda:0", 0),
        ],
        ids=["int", "torch_device", "string"],
    )
    def test_set_cuda_context_parametrized(self, device_input, expected_device_id):
        """Test setting CUDA context in isolated threads."""
        # A single-worker pool gives a fresh thread, which the helper expects
        # to start without a CUDA context.
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(
                run_cuda_test_in_thread, device_input, expected_device_id
            )
            success, message = future.result(timeout=30)
        assert success, message

    def test_set_cuda_context_invalid_device_type(self):
        """Test error handling for invalid device type."""
        with pytest.raises(ValueError, match="Expected a cuda device"):
            current_platform.set_device(torch.device("cpu"))


if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly exits
    # non-zero when a test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))