Commit 66b809cc authored by zhuwenwen
Browse files

Merge tag 'v0.7.2' into v0.7.2-dev

parents 37b63c24 0408efc6
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE layers.
Run `pytest tests/kernels/test_moe.py`.
......
# SPDX-License-Identifier: Apache-2.0
from itertools import accumulate, product
from typing import Dict, List, Optional
......
# SPDX-License-Identifier: Apache-2.0
import math
import random
import time
......
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import patch
import pytest
import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.rocm import RocmPlatform
@pytest.fixture(autouse=True)
def clear_cache():
    """Empty the cached attention-backend lookup for each test.

    Declared with ``autouse=True`` so it applies to every test in this
    module without being requested explicitly; this keeps a backend
    chosen in one test case from being reused by another.
    """
    _cached_get_attn_backend.cache_clear()
def test_selector(monkeypatch):
    """Check which attention backend the selector returns on ROCm.

    The backend override is set to ``"ROCM_FLASH"`` and the selector's
    ``current_platform`` is patched with a ``RocmPlatform`` instance, then
    two lookups are made and the reported backend names are asserted.
    """
    override_backend_env_variable(monkeypatch, "ROCM_FLASH")

    rocm_platform = RocmPlatform()
    with patch("vllm.attention.selector.current_platform", rocm_platform):
        # Default lookup: expected to resolve to the ROCm flash backend.
        flash_backend = get_attn_backend(
            16, torch.float16, torch.float16, 16, False
        )
        assert flash_backend.get_name() == "ROCM_FLASH"

        # MLA lookup (DeepSeek-related case): same call shape with two
        # extra trailing positional flags.
        mla_backend = get_attn_backend(
            576, torch.bfloat16, "auto", 16, False, False, True
        )
        assert mla_backend.get_name() == "TRITON_MLA"
# SPDX-License-Identifier: Apache-2.0
"""
Tests for miscellaneous utilities
"""
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for the triton_scaled_mm kernel
Run `pytest tests/kernels/test_triton_scaled_mm.py`.
......
# SPDX-License-Identifier: Apache-2.0
"""
Tests for miscellaneous utilities
"""
......
# SPDX-License-Identifier: Apache-2.0
import torch
from tests.kernels.utils import opcheck
......
# SPDX-License-Identifier: Apache-2.0
import os
import pytest
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple
import flashinfer
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
......
# SPDX-License-Identifier: Apache-2.0
import gguf
import pytest
import torch
......
# SPDX-License-Identifier: Apache-2.0
import torch
from tests.kernels.utils import opcheck
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import torch.nn.functional as F
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for the marlin kernel.
Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
......
# SPDX-License-Identifier: Apache-2.0
"""Kernel test utils"""
import itertools
......
# SPDX-License-Identifier: Apache-2.0
import os
import subprocess
import sys
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment