conftest.py 1.01 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import pytest

from vllm.platforms.interface import DeviceCapability


@pytest.fixture
def mock_cuda_platform():
    """
    Provide a factory that temporarily replaces ``vllm.platforms.current_platform``.

    The returned factory is a context manager. While its ``with`` block is
    active, ``vllm.platforms.current_platform`` is a ``MagicMock`` configured so that:

    * ``is_cuda()`` returns ``is_cuda`` (default ``True``);
    * ``get_device_capability()`` returns ``DeviceCapability(*capability)``
      when ``capability`` is given.

    The mock is yielded so tests can configure or inspect it further. The
    original attribute is restored when the block exits.

    Usage:
        def test_something(mock_cuda_platform):
            with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
                # test code

    Note:
        Only the methods listed above are configured. Calling any other
        method on the mock returns a fresh ``MagicMock``, and a ``MagicMock``
        is truthy. The same applies to ``get_device_capability()`` when
        ``capability`` is None. Code that did ``from vllm.platforms import
        current_platform`` before entering the block holds its own reference
        and will not see the mock.
    """

    def _make_fake_platform(
        is_cuda: bool, capability: tuple[int, int] | None
    ) -> MagicMock:
        # Build the stand-in platform object; patching is done by the caller.
        fake_platform = MagicMock()
        fake_platform.is_cuda.return_value = is_cuda
        if capability is None:
            # Leave get_device_capability unconfigured (returns a MagicMock).
            return fake_platform
        fake_platform.get_device_capability.return_value = DeviceCapability(
            *capability
        )
        return fake_platform

    @contextmanager
    def _patched_platform(
        is_cuda: bool = True, capability: tuple[int, int] | None = None
    ):
        fake = _make_fake_platform(is_cuda, capability)
        # Swap the module attribute only for the duration of the caller's block.
        with patch("vllm.platforms.current_platform", fake):
            yield fake

    return _patched_platform