Unverified Commit 69f8a0ea authored by Rabi Mishra's avatar Rabi Mishra Committed by GitHub
Browse files

fix(rocm): Use refresh_env_variables() for rocm_aiter_ops in test_moe (#31711)


Signed-off-by: default avatarrabi <ramishra@redhat.com>
parent f28125d8
...@@ -6,8 +6,6 @@ Run `pytest tests/kernels/test_moe.py`. ...@@ -6,8 +6,6 @@ Run `pytest tests/kernels/test_moe.py`.
""" """
import functools import functools
import importlib
import sys
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
...@@ -592,14 +590,12 @@ def test_mixtral_moe( ...@@ -592,14 +590,12 @@ def test_mixtral_moe(
"""Make sure our Mixtral MoE implementation agrees with the one from """Make sure our Mixtral MoE implementation agrees with the one from
huggingface.""" huggingface."""
# clear the cache before every test # Explicitly set AITER env var based on test parameter to ensure
# Force reload aiter_ops to pick up the new environment variables. # consistent behavior regardless of external environment
if "rocm_aiter_ops" in sys.modules: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_rocm_aiter else "0")
importlib.reload(rocm_aiter_ops) rocm_aiter_ops.refresh_env_variables()
if use_rocm_aiter: if use_rocm_aiter and dtype == torch.float32:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32") pytest.skip("AITER ROCm test skip for float32")
monkeypatch.setenv("RANK", "0") monkeypatch.setenv("RANK", "0")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment