Unverified Commit f42a006b authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[Bugfix]: During testing, use pytest monkeypatch for safely overriding the env...

[Bugfix]: During testing, use pytest monkeypatch for safely overriding the env var that indicates the vLLM backend (#5210)
parent 3a434b07
import os
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import torch import torch
from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL,
override_backend_env_variable)
from vllm.attention.selector import which_attn_to_use from vllm.attention.selector import which_attn_to_use
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(name: str, device: str): def test_env(name: str, device: str, monkeypatch):
"""Test that the attention selector can be set via environment variable. """Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend. Note that we do not test FlashAttn because it is the default backend.
""" """
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
os.environ["VLLM_ATTENTION_BACKEND"] = name override_backend_env_variable(monkeypatch, name)
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True): with patch("vllm.attention.selector.is_cpu", return_value=True):
...@@ -32,14 +33,11 @@ def test_env(name: str, device: str): ...@@ -32,14 +33,11 @@ def test_env(name: str, device: str):
torch.float16, 16) torch.float16, 16)
assert backend.name == name assert backend.name == name
if name_backup is not None:
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
def test_flash_attn(): def test_flash_attn(monkeypatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
# Unsupported CUDA arch # Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=[7, 5]): with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
...@@ -71,14 +69,9 @@ def test_flash_attn(): ...@@ -71,14 +69,9 @@ def test_flash_attn():
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
assert backend.name != "FLASH_ATTN" assert backend.name != "FLASH_ATTN"
if name_backup is not None:
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
def test_invalid_env(): def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid.""" """Throw an exception if the backend name is invalid."""
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
with pytest.raises(ValueError): with pytest.raises(ValueError):
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
"""Kernel test utils"""
import pytest
# Name of the environment variable that indicates the vLLM attention backend.
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Value of that variable used by the tests to request the FlashAttn backend.
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
# Deliberately invalid backend name, used by the tests to trigger the
# error path.
STR_INVALID_VAL: str = "INVALID"
def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    """Force the vLLM attention backend for the duration of one test.

    The backend-selection environment variable is set through pytest's
    monkeypatch fixture rather than by writing to ``os.environ`` directly,
    so the previous value (or absence) of the variable is restored when
    the test context exits.

    Arguments:

    * mpatch: the pytest monkeypatch instance of the calling test
    * backend_name: attention backend name to write into the variable
    """
    mpatch.setenv(name=STR_BACKEND_ENV_VAR, value=backend_name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment