conftest.py 1.12 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM tests."""

5
6
import warnings

7
8
9
10
11
12
13
14
15
16
17
18
import torch

from vllm.platforms import current_platform


def pytest_configure(config):
    """Force the math SDP backend on ROCm for the test session.

    On ROCm, the flash and memory-efficient scaled-dot-product (SDP)
    backends are disabled and the math backend is enabled, to avoid
    HuggingFace Transformers accuracy issues. Nothing is changed on other
    platforms, or when any invocation argument mentions a test file listed
    in ``skip_patterns``.

    Args:
        config: The pytest config object; only ``config.args`` is read.
    """
    if not current_platform.is_rocm():
        return

    # Test files that must keep the default SDP backends. Matching is a plain
    # substring test against every invocation argument, so a single matching
    # argument leaves the defaults in place for the whole session.
    skip_patterns = ["test_granite_speech.py"]
    if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns):
        # Skip disabling SDP for Granite Speech tests on ROCm
        return

    # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
    # accuracy issues
    # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)

    # Make the backend override visible in the test output.
    warnings.warn(
        "ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp "
        "to avoid HuggingFace Transformers accuracy issues",
        UserWarning,
        stacklevel=1,
    )