Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
...@@ -8,6 +8,7 @@ import pytest ...@@ -8,6 +8,7 @@ import pytest
import torch import torch
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
if not current_platform.is_cpu(): if not current_platform.is_cpu():
...@@ -190,7 +191,7 @@ def varlen_with_paged_kv( ...@@ -190,7 +191,7 @@ def varlen_with_paged_kv(
use_sink: bool, use_sink: bool,
isa: str, isa: str,
) -> None: ) -> None:
current_platform.seed_everything(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens] kv_lens = [x[1] for x in seq_lens]
......
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
try: try:
if current_platform.is_rocm(): if current_platform.is_rocm():
...@@ -132,7 +133,7 @@ def test_varlen_with_paged_kv( ...@@ -132,7 +133,7 @@ def test_varlen_with_paged_kv(
"Flash attention with quantized inputs is only " "Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type" "supported on version 3 with bfloat16 base type"
) )
current_platform.seed_everything(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens] kv_lens = [x[1] for x in seq_lens]
......
...@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import ( ...@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
if not current_platform.is_device_capability_family(100): if not current_platform.is_device_capability_family(100):
pytest.skip( pytest.skip(
...@@ -80,7 +81,7 @@ def test_flashinfer_trtllm_decode_with_baseline( ...@@ -80,7 +81,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
has_sinks: bool, has_sinks: bool,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(42) set_random_seed(42)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype q_quant_dtype = q_quant_dtype or dtype
...@@ -279,7 +280,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( ...@@ -279,7 +280,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
has_sinks: bool, has_sinks: bool,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(42) set_random_seed(42)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype q_quant_dtype = q_quant_dtype or dtype
......
...@@ -7,12 +7,12 @@ import random ...@@ -7,12 +7,12 @@ import random
import pytest import pytest
import torch import torch
from vllm.attention.ops.flashmla import ( from vllm.triton_utils import triton
from vllm.v1.attention.ops.flashmla import (
flash_mla_with_kvcache, flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
is_flashmla_dense_supported, is_flashmla_dense_supported,
) )
from vllm.triton_utils import triton
def cal_diff( def cal_diff(
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
def test_sparse_flashmla_metadata_smoke(): def test_sparse_flashmla_metadata_smoke():
import vllm.attention.ops.flashmla as fm import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported() ok, reason = fm.is_flashmla_sparse_supported()
if not ok: if not ok:
...@@ -34,7 +34,7 @@ def test_sparse_flashmla_metadata_smoke(): ...@@ -34,7 +34,7 @@ def test_sparse_flashmla_metadata_smoke():
def test_sparse_flashmla_decode_smoke(): def test_sparse_flashmla_decode_smoke():
import vllm.attention.ops.flashmla as fm import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported() ok, reason = fm.is_flashmla_sparse_supported()
if not ok: if not ok:
...@@ -97,7 +97,7 @@ def test_sparse_flashmla_decode_smoke(): ...@@ -97,7 +97,7 @@ def test_sparse_flashmla_decode_smoke():
def test_sparse_flashmla_prefill_smoke(): def test_sparse_flashmla_prefill_smoke():
import vllm.attention.ops.flashmla as fm import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported() ok, reason = fm.is_flashmla_sparse_supported()
if not ok: if not ok:
......
...@@ -5,7 +5,7 @@ import pytest ...@@ -5,7 +5,7 @@ import pytest
import torch import torch
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
NUM_HEADS = [4, 8] NUM_HEADS = [4, 8]
HEAD_SIZES = [64] HEAD_SIZES = [64]
...@@ -124,7 +124,7 @@ def test_linear_decode_forward_triton( ...@@ -124,7 +124,7 @@ def test_linear_decode_forward_triton(
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed_all(42) torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42) set_random_seed(42)
base = 0.01 base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
...@@ -167,7 +167,7 @@ def test_linear_decode_forward_triton_with_padding( ...@@ -167,7 +167,7 @@ def test_linear_decode_forward_triton_with_padding(
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed_all(42) torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42) set_random_seed(42)
batch_size = 4 batch_size = 4
base = 0.01 base = 0.01
...@@ -231,7 +231,7 @@ def test_lightning_attention_reference( ...@@ -231,7 +231,7 @@ def test_lightning_attention_reference(
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed_all(42) torch.cuda.manual_seed_all(42)
current_platform.seed_everything(42) set_random_seed(42)
base = 0.01 base = 0.01
q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
......
...@@ -5,10 +5,10 @@ import pytest ...@@ -5,10 +5,10 @@ import pytest
import torch import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.attention.ops.triton_merge_attn_states import ( from vllm.platforms import current_platform
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton, merge_attn_states as merge_attn_states_triton,
) )
from vllm.platforms import current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 # Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
......
...@@ -3,21 +3,23 @@ ...@@ -3,21 +3,23 @@
""" """
Test: Test:
* Tests for MultiHeadAttention layer * Tests for MMEncoderAttention layer
""" """
import itertools
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -34,7 +36,7 @@ if current_platform.is_rocm(): ...@@ -34,7 +36,7 @@ if current_platform.is_rocm():
@pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("device", devices)
def test_mha_attn_platform(device: str): def test_mha_attn_platform(default_vllm_config, device: str):
""" """
Test the attention selector between different platform and device. Test the attention selector between different platform and device.
""" """
...@@ -42,35 +44,31 @@ def test_mha_attn_platform(device: str): ...@@ -42,35 +44,31 @@ def test_mha_attn_platform(device: str):
if device == "cpu": if device == "cpu":
with ( with (
patch("vllm.attention.layer.current_platform", CpuPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
elif device == "hip": elif device == "hip":
with ( with (
patch("vllm.attention.layer.current_platform", RocmPlatform()),
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
else: else:
# Test CUDA with head_size=64 (divisible by 32) # Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention
with ( with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
): ):
attn = MultiHeadAttention(16, 64, scale=1) attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32) # Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention
with ( with (
patch("vllm.attention.layer.current_platform", CudaPlatform()),
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
): ):
attn = MultiHeadAttention(16, 72, scale=1) attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
...@@ -94,6 +92,10 @@ def ref_attention( ...@@ -94,6 +92,10 @@ def ref_attention(
BATCH_SIZES = [1, 16] BATCH_SIZES = [1, 16]
SEQ_LENS = [1] SEQ_LENS = [1]
VAR_SEQ_LENS = [
[2, 2],
[2, 3, 4],
]
NUM_HEADS = [1, 16] NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1] NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80] HEAD_SIZES = [64, 80]
...@@ -114,6 +116,7 @@ CUDA_DEVICES = ["cuda"] ...@@ -114,6 +116,7 @@ CUDA_DEVICES = ["cuda"]
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward( def test_mha_attn_forward(
default_vllm_config,
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
num_heads: int, num_heads: int,
...@@ -122,7 +125,7 @@ def test_mha_attn_forward( ...@@ -122,7 +125,7 @@ def test_mha_attn_forward(
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str,
): ):
current_platform.seed_everything(0) set_random_seed(0)
torch.set_default_device(device) torch.set_default_device(device)
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
...@@ -130,7 +133,7 @@ def test_mha_attn_forward( ...@@ -130,7 +133,7 @@ def test_mha_attn_forward(
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size) k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
scale = 1.0 / head_size**0.5 scale = 1.0 / head_size**0.5
attn = MultiHeadAttention( attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
) )
output = attn(q, k, v) output = attn(q, k, v)
...@@ -151,3 +154,59 @@ def test_mha_attn_forward( ...@@ -151,3 +154,59 @@ def test_mha_attn_forward(
scale=scale, scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size) ).reshape(batch_size, seq_len, num_heads * head_size)
torch.testing.assert_close(output, ref_output) torch.testing.assert_close(output, ref_output)
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward(
default_vllm_config,
var_seq_len: list[int],
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: str,
):
set_random_seed(0)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
q = torch.randn(1, sum(var_seq_len), num_heads, head_size)
k = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size)
v = torch.randn(1, sum(var_seq_len), num_kv_heads, head_size)
cu_seqlens = torch.tensor(
[0] + list(itertools.accumulate(var_seq_len)), dtype=torch.int32
)
scale = 1.0 / head_size**0.5
attn = MMEncoderAttention(
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
)
output = attn(
q, k, v, cu_seqlens=cu_seqlens, max_seqlen=torch.tensor(max(var_seq_len))
)
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
if num_queries_per_kv > 1:
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
ref_output = []
for q_i, k_i, v_i in zip(
torch.split(q, var_seq_len, dim=1),
torch.split(k, var_seq_len, dim=1),
torch.split(v, var_seq_len, dim=1),
):
output_i = ref_attention(
q_i,
k_i,
v_i,
scale=scale,
)
ref_output.append(output_i)
ref_output = torch.cat(ref_output, dim=1)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import torch import torch
from torch.testing import assert_close from torch.testing import assert_close
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
def test_pack_seq_basic_fp8(): def test_pack_seq_basic_fp8():
......
...@@ -10,10 +10,12 @@ import pytest ...@@ -10,10 +10,12 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode,
)
from vllm.v1.attention.ops.prefix_prefill import context_attention_fwd
if not current_platform.is_rocm(): if not current_platform.is_rocm():
from xformers import ops as xops from xformers import ops as xops
...@@ -117,6 +119,7 @@ def test_contexted_kv_attention( ...@@ -117,6 +119,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
op: Callable, op: Callable,
block_size: int = 32,
) -> None: ) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip( pytest.skip(
...@@ -130,7 +133,7 @@ def test_contexted_kv_attention( ...@@ -130,7 +133,7 @@ def test_contexted_kv_attention(
): ):
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache") pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
current_platform.seed_everything(0) set_random_seed(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # Need this, otherwise when we capture the graph the process
...@@ -143,7 +146,6 @@ def test_contexted_kv_attention( ...@@ -143,7 +146,6 @@ def test_contexted_kv_attention(
MAX_CTX_LEN = 1024 MAX_CTX_LEN = 1024
BS = 10 BS = 10
cache_size = 640 cache_size = 640
block_size = 32
max_block_per_request = 64 max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode # ensure one sequence in batch is a decode
...@@ -338,6 +340,7 @@ def test_contexted_kv_attention_alibi( ...@@ -338,6 +340,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
op: Callable, op: Callable,
block_size: int = 32,
) -> None: ) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip( pytest.skip(
...@@ -351,7 +354,7 @@ def test_contexted_kv_attention_alibi( ...@@ -351,7 +354,7 @@ def test_contexted_kv_attention_alibi(
): ):
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache") pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
current_platform.seed_everything(0) set_random_seed(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # Need this, otherwise when we capture the graph the process
...@@ -390,7 +393,6 @@ def test_contexted_kv_attention_alibi( ...@@ -390,7 +393,6 @@ def test_contexted_kv_attention_alibi(
MAX_CTX_LEN = 1024 MAX_CTX_LEN = 1024
BS = 10 BS = 10
cache_size = 640 cache_size = 640
block_size = 32
max_block_per_request = 64 max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
...@@ -643,3 +645,34 @@ def test_contexted_kv_attention_alibi_f32( ...@@ -643,3 +645,34 @@ def test_contexted_kv_attention_alibi_f32(
test_contexted_kv_attention_alibi( test_contexted_kv_attention_alibi(
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
) )
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_qwen3_nonstandard_block_size(
head_size: int,
dtype: torch.dtype,
device: str,
op: Callable,
) -> None:
"""
A separate test function specifically added
for Qwen3-Next-80B (Block Size 544).
"""
if not current_platform.is_rocm():
pytest.skip("544 block size optimization is only for ROCm.")
test_contexted_kv_attention(
num_heads=64,
num_queries_per_kv=1,
head_size=head_size,
block_size=544,
sliding_window=0,
dtype=dtype,
kv_cache_dtype="auto",
device=device,
op=op,
)
...@@ -4,8 +4,10 @@ ...@@ -4,8 +4,10 @@
import pytest import pytest
import torch import torch
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -16,40 +18,56 @@ def clear_cache(): ...@@ -16,40 +18,56 @@ def clear_cache():
@pytest.mark.skip(reason="Skipped for now. Should be revisited.") @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch): def test_selector(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m: # Set the current platform to ROCm using monkeypatch
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN") monkeypatch.setattr("vllm.v1.attention.selector.current_platform", RocmPlatform())
# Set the current platform to ROCm using monkeypatch # Test standard ROCm attention
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
vllm_config = VllmConfig(attention_config=attention_config)
# Test standard ROCm attention with set_current_vllm_config(vllm_config):
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN" assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
# MLA test for deepseek related # MLA test for deepseek related
# Change the attention backend to triton MLA
attention_config = AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA)
vllm_config = VllmConfig(attention_config=attention_config)
# change the attention backend to triton MLA with set_current_vllm_config(vllm_config):
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA" assert backend.get_name() == "TRITON_MLA"
# If attention backend is None # If attention backend is None
# If use_mla is true # If use_mla is true
# The selected backend is triton MLA # The selected backend is triton MLA
m.setenv("VLLM_ATTENTION_BACKEND", "") attention_config = AttentionConfig(backend=None)
vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True) backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
assert backend.get_name() == "TRITON_MLA" assert backend.get_name() == "TRITON_MLA"
# change the attention backend to AITER MLA # Change the attention backend to AITER MLA
# m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA") attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_AITER_MLA)
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True) vllm_config = VllmConfig(attention_config=attention_config)
# assert backend.get_name() == "ROCM_AITER_MLA"
# with set_current_vllm_config(vllm_config):
# # If attention backend is None # backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
# # If use_mla is true # assert backend.get_name() == "ROCM_AITER_MLA"
# # If VLLM_ROCM_USE_AITER is enabled
# # The selected backend is ROCM_AITER_MLA # # If attention backend is None
# m.setenv("VLLM_ATTENTION_BACKEND", "") # # If use_mla is true
# m.setenv("VLLM_ROCM_USE_AITER", "1") # # If VLLM_ROCM_USE_AITER is enabled
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True) # # The selected backend is ROCM_AITER_MLA
# assert backend.get_name() == "ROCM_AITER_MLA" # with monkeypatch.context() as m:
# m.setenv("VLLM_ROCM_USE_AITER", "1")
# attention_config = AttentionConfig(backend=None)
# vllm_config = VllmConfig(attention_config=attention_config)
# with set_current_vllm_config(vllm_config):
# backend = get_attn_backend(
# 576, torch.bfloat16, "auto", 1, False, use_mla=True
# )
# assert backend.get_name() == "ROCM_AITER_MLA"
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import pytest import pytest
import torch import torch
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
@pytest.mark.parametrize("B", [3, 5]) @pytest.mark.parametrize("B", [3, 5])
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
def ref_masked_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
is_causal: bool = True,
sliding_window_q: int | None = None,
sliding_window_k: int | None = None,
) -> torch.Tensor:
"""Reference implementation using PyTorch SDPA."""
# q, k, v: [total_tokens, num_heads, head_dim]
# SDPA expects [batch, num_heads, seq_len, head_dim]
total_tokens = q.shape[0]
# Add batch dimension and transpose
q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
k = k.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
v = v.unsqueeze(0).transpose(1, 2) # [1, num_heads, total_tokens, head_dim]
# Create attention mask if needed
attn_mask = None
use_causal = is_causal
# If we have sliding window or need custom masking, create explicit mask
sliding_window_q = sliding_window_q if sliding_window_q is not None else 0
sliding_window_k = sliding_window_k if sliding_window_k is not None else 0
if (sliding_window_q > 0) or (sliding_window_k > 0):
# Position indices
pos_q = torch.arange(total_tokens, device=q.device).unsqueeze(1)
pos_k = torch.arange(total_tokens, device=q.device).unsqueeze(0)
# Start with valid mask (False = no masking)
mask = torch.ones(
(total_tokens, total_tokens), dtype=torch.bool, device=q.device
)
# Apply causal mask
if is_causal:
mask = mask & (pos_q >= pos_k)
# Apply sliding window masks
sliding_window_mask = torch.ones_like(mask)
if sliding_window_q > 0:
sliding_window_mask &= pos_q - pos_k <= sliding_window_q
if sliding_window_k > 0:
sliding_window_mask &= pos_k - pos_q <= sliding_window_k
mask = mask & sliding_window_mask
attn_mask = torch.where(mask, 0.0, float("-inf")).to(q.dtype)
use_causal = False # Don't use is_causal when providing explicit mask
# Use SDPA
output = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=use_causal, dropout_p=0.0
)
# Convert back to original shape: [total_tokens, num_heads, head_dim]
output = output.transpose(1, 2).squeeze(0)
return output
@pytest.mark.parametrize("B", [5])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("is_causal", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention(
B: int,
max_seq_len: int,
H_Q: int,
H_KV: int,
D: int,
is_causal: bool,
dtype: torch.dtype,
):
"""Test basic context attention without sliding window."""
torch.manual_seed(42)
# Generate random sequence lengths for each batch
seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
total_tokens = seq_lens.sum().item()
# Create batch start locations
b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
# Create input tensors
q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
o = torch.zeros_like(q)
# Call Triton kernel
context_attention_fwd(
q,
k,
v,
o,
b_start_loc,
seq_lens,
max_seq_len,
is_causal=is_causal,
sliding_window_q=None,
sliding_window_k=None,
)
# Compute reference output for each sequence in batch
o_ref = torch.zeros_like(q)
for i in range(B):
start = b_start_loc[i].item()
end = start + seq_lens[i].item()
q_seq = q[start:end]
k_seq = k[start:end]
v_seq = v[start:end]
# Expand KV heads if using GQA
if H_Q != H_KV:
kv_group_num = H_Q // H_KV
k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)
o_ref[start:end] = ref_masked_attention(
q_seq,
k_seq,
v_seq,
is_causal=is_causal,
sliding_window_q=None,
sliding_window_k=None,
)
# Compare outputs
torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("sliding_window", [(32, 32), (32, 0), (0, 32)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention_sliding_window(
B: int,
max_seq_len: int,
H_Q: int,
H_KV: int,
D: int,
sliding_window: tuple[int, int],
dtype: torch.dtype,
):
sliding_window_q, sliding_window_k = sliding_window
"""Test context attention with sliding window."""
torch.manual_seed(42)
# Generate random sequence lengths for each batch
seq_lens = torch.randint(max_seq_len // 2, max_seq_len + 1, (B,), device="cuda")
total_tokens = seq_lens.sum().item()
# Create batch start locations
b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
# Create input tensors
q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
o = torch.zeros_like(q)
# Call Triton kernel
context_attention_fwd(
q,
k,
v,
o,
b_start_loc,
seq_lens,
max_seq_len,
is_causal=False,
sliding_window_q=sliding_window_q,
sliding_window_k=sliding_window_k,
)
# Compute reference output for each sequence in batch
o_ref = torch.zeros_like(q)
for i in range(B):
start = b_start_loc[i].item()
end = start + seq_lens[i].item()
q_seq = q[start:end]
k_seq = k[start:end]
v_seq = v[start:end]
# Expand KV heads if using GQA
if H_Q != H_KV:
kv_group_num = H_Q // H_KV
k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)
o_ref[start:end] = ref_masked_attention(
q_seq,
k_seq,
v_seq,
is_causal=False,
sliding_window_q=sliding_window_q if sliding_window_q > 0 else None,
sliding_window_k=sliding_window_k if sliding_window_k > 0 else None,
)
# Compare outputs
torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
import pytest import pytest
import torch import torch
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
NUM_HEADS = [(4, 4), (8, 2)] NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
...@@ -113,7 +114,7 @@ def test_triton_unified_attn( ...@@ -113,7 +114,7 @@ def test_triton_unified_attn(
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens] kv_lens = [x[1] for x in seq_lens]
......
...@@ -6,11 +6,13 @@ from unittest.mock import patch ...@@ -6,11 +6,13 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -73,18 +75,18 @@ def generate_params(): ...@@ -73,18 +75,18 @@ def generate_params():
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_env( def test_backend_selection(
device: str, device: str,
name: str, name: str,
use_mla: bool, use_mla: bool,
block_size: int, block_size: int,
monkeypatch: pytest.MonkeyPatch,
): ):
"""Test attention backend selection with valid device-backend pairs.""" """Test attention backend selection with valid device-backend pairs."""
with monkeypatch.context() as m: # Create AttentionConfig with the specified backend
m.setenv("VLLM_ATTENTION_BACKEND", name) attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size) backend = get_attn_backend(16, torch.float16, None, block_size)
...@@ -180,7 +182,7 @@ def test_env( ...@@ -180,7 +182,7 @@ def test_env(
expected = name expected = name
assert backend.get_name() == expected assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA": elif name == "FLASH_ATTN_MLA":
from vllm.attention.utils.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla, flash_attn_supports_mla,
) )
...@@ -217,27 +219,32 @@ def test_env( ...@@ -217,27 +219,32 @@ def test_env(
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str): def test_fp32_fallback(device: str):
"""Test attention backend selection with fp32.""" """Test attention backend selection with fp32."""
if device == "cpu": # Use default config (no backend specified)
with patch("vllm.platforms.current_platform", CpuPlatform()): vllm_config = VllmConfig()
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "CPU_ATTN"
elif device == "cuda": with set_current_vllm_config(vllm_config):
with patch("vllm.platforms.current_platform", CudaPlatform()): if device == "cpu":
backend = get_attn_backend(16, torch.float32, None, 16) with patch("vllm.platforms.current_platform", CpuPlatform()):
assert backend.get_name() == "FLEX_ATTENTION" backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "CPU_ATTN"
elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "FLEX_ATTENTION"
def test_flash_attn(monkeypatch: pytest.MonkeyPatch): def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
pytest.skip( pytest.skip(
"Skipping as current backend selector does not " "Skipping as current backend selector does not "
"handle fallbacks when a backend is set via env var." "handle fallbacks when a backend is explicitly set."
) )
with monkeypatch.context() as m: attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN") vllm_config = VllmConfig(attention_config=attention_config)
with set_current_vllm_config(vllm_config):
# Unsupported CUDA arch # Unsupported CUDA arch
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16) backend = get_attn_backend(16, torch.float16, None, 16)
...@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): ...@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
def test_invalid_env(monkeypatch: pytest.MonkeyPatch): def test_invalid_backend():
"""Test that invalid attention backend names raise ValueError.""" """Test that invalid attention backend names raise ValueError."""
with ( with (
monkeypatch.context() as m, pytest.raises(ValueError),
patch("vllm.platforms.current_platform", CudaPlatform()),
): ):
m.setenv("VLLM_ATTENTION_BACKEND", "INVALID") # Invalid backend name should raise ValueError when creating enum
AttentionConfig(backend=AttentionBackendEnum["INVALID"])
# Should raise ValueError for invalid backend
with pytest.raises(ValueError) as exc_info:
get_attn_backend(32, torch.float16, None, 16)
assert "Invalid value 'INVALID'" in str(exc_info.value)
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import pytest import pytest
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
try: try:
import flashinfer import flashinfer
...@@ -101,7 +102,7 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -101,7 +102,7 @@ def test_flashinfer_decode_with_paged_kv(
sliding_window: int | None, sliding_window: int | None,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) set_random_seed(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
num_query_heads = num_heads[0] num_query_heads = num_heads[0]
num_kv_heads = num_heads[1] num_kv_heads = num_heads[1]
...@@ -196,7 +197,7 @@ def test_flashinfer_prefill_with_paged_kv( ...@@ -196,7 +197,7 @@ def test_flashinfer_prefill_with_paged_kv(
sliding_window: int | None, sliding_window: int | None,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens] kv_lens = [x[1] for x in seq_lens]
...@@ -299,7 +300,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( ...@@ -299,7 +300,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
) -> None: ) -> None:
pytest.skip("TODO: fix the accuracy issue") pytest.skip("TODO: fix the accuracy issue")
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) set_random_seed(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens] kv_lens = [x[1] for x in seq_lens]
...@@ -409,7 +410,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( ...@@ -409,7 +410,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
) -> None: ) -> None:
# test doesn't work for num_heads = (16,16) # test doesn't work for num_heads = (16,16)
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) set_random_seed(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
num_query_heads = num_heads[0] num_query_heads = num_heads[0]
num_kv_heads = num_heads[1] num_kv_heads = num_heads[1]
......
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.activation import ( ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.activation import (
SiluAndMul, SiluAndMul,
SwigluOAIAndMul, SwigluOAIAndMul,
) )
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
...@@ -45,6 +45,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e ...@@ -45,6 +45,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_act_and_mul( def test_act_and_mul(
default_vllm_config,
activation: str, activation: str,
num_tokens: int, num_tokens: int,
d: int, d: int,
...@@ -52,7 +53,7 @@ def test_act_and_mul( ...@@ -52,7 +53,7 @@ def test_act_and_mul(
seed: int, seed: int,
device: str, device: str,
) -> None: ) -> None:
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype) x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu_and_mul": if activation == "silu_and_mul":
...@@ -122,6 +123,7 @@ def test_act_and_mul( ...@@ -122,6 +123,7 @@ def test_act_and_mul(
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_activation( def test_activation(
default_vllm_config,
activation: type[torch.nn.Module], activation: type[torch.nn.Module],
num_tokens: int, num_tokens: int,
d: int, d: int,
...@@ -129,7 +131,7 @@ def test_activation( ...@@ -129,7 +131,7 @@ def test_activation(
seed: int, seed: int,
device: str, device: str,
) -> None: ) -> None:
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype) x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation[0]() layer = activation[0]()
......
...@@ -8,11 +8,13 @@ from tests.kernels.utils import opcheck ...@@ -8,11 +8,13 @@ from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
DTYPES = [torch.bfloat16, torch.float16] DTYPES = [torch.bfloat16, torch.float16]
IS_NEOX = [True, False] IS_NEOX = [True, False]
EPS_VALUES = [1e-5, 1e-6] EPS_VALUES = [1e-5, 1e-6]
SEEDS = [13] SEEDS = [13]
PARTIAL_ROPE = [True, False]
CUDA_DEVICES = ["cuda:0"] CUDA_DEVICES = ["cuda:0"]
...@@ -52,16 +54,19 @@ def _apply_qk_norm_rope( ...@@ -52,16 +54,19 @@ def _apply_qk_norm_rope(
@pytest.mark.parametrize("is_neox", IS_NEOX) @pytest.mark.parametrize("is_neox", IS_NEOX)
@pytest.mark.parametrize("eps", EPS_VALUES) @pytest.mark.parametrize("eps", EPS_VALUES)
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25])
@torch.inference_mode() @torch.inference_mode()
def test_fused_qk_norm_rope_matches_reference( def test_fused_qk_norm_rope_matches_reference(
default_vllm_config,
device: str, device: str,
dtype: torch.dtype, dtype: torch.dtype,
is_neox: bool, is_neox: bool,
eps: float, eps: float,
seed: int, seed: int,
rotary_ratio: float,
): ):
torch.set_default_device(device) torch.set_default_device(device)
current_platform.seed_everything(seed) set_random_seed(seed)
num_heads, num_kv_heads, head_dim = 16, 4, 128 num_heads, num_kv_heads, head_dim = 16, 4, 128
num_tokens = 4 num_tokens = 4
...@@ -76,10 +81,10 @@ def test_fused_qk_norm_rope_matches_reference( ...@@ -76,10 +81,10 @@ def test_fused_qk_norm_rope_matches_reference(
k_norm.weight.data.normal_(mean=1.0, std=0.1) k_norm.weight.data.normal_(mean=1.0, std=0.1)
q_weight = q_norm.weight.data q_weight = q_norm.weight.data
k_weight = k_norm.weight.data k_weight = k_norm.weight.data
rotary_dim = int(head_dim * rotary_ratio)
rope = RotaryEmbedding( rope = RotaryEmbedding(
head_size=head_dim, head_size=head_dim,
rotary_dim=head_dim, rotary_dim=rotary_dim,
max_position_embeddings=4096, max_position_embeddings=4096,
base=10000.0, base=10000.0,
is_neox_style=is_neox, is_neox_style=is_neox,
......
...@@ -147,6 +147,7 @@ def ops_impl( ...@@ -147,6 +147,7 @@ def ops_impl(
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_rms_norm( def test_rms_norm(
default_vllm_config,
num_tokens: int, num_tokens: int,
hidden_size: int, hidden_size: int,
add_residual: bool, add_residual: bool,
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.quant_utils import FP8_DTYPE
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
...@@ -26,6 +26,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e ...@@ -26,6 +26,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
@pytest.mark.parametrize("strided_input", [False, True]) @pytest.mark.parametrize("strided_input", [False, True])
@torch.inference_mode() @torch.inference_mode()
def test_rms_norm( def test_rms_norm(
default_vllm_config,
num_tokens: int, num_tokens: int,
hidden_size: int, hidden_size: int,
add_residual: bool, add_residual: bool,
...@@ -34,7 +35,7 @@ def test_rms_norm( ...@@ -34,7 +35,7 @@ def test_rms_norm(
device: str, device: str,
strided_input: bool, strided_input: bool,
) -> None: ) -> None:
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype) layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1) layer.weight.data.normal_(mean=1.0, std=0.1)
...@@ -70,6 +71,80 @@ def test_rms_norm( ...@@ -70,6 +71,80 @@ def test_rms_norm(
) )
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
def test_fused_rms_norm_quant(
num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
quant_scale: float,
seed: int,
device: str,
strided_input: bool,
) -> None:
set_random_seed(seed)
torch.set_default_device(device)
weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
last_dim = 2 * hidden_size if strided_input else hidden_size
x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
x = x_base[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale
if add_residual:
residual = torch.randn_like(x) * scale
residual_fused = residual.clone()
else:
residual = residual_fused = None
out_norm = torch.empty_like(x)
out_quant = torch.empty_like(x, dtype=FP8_DTYPE)
out_quant_fused = torch.empty_like(out_quant)
quant_scale_t = torch.tensor(quant_scale, dtype=torch.float32)
if add_residual:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6
)
# Unfused kernel is in-place so it goes second
# Also use a separate clone of x to avoid modifying the input
x_unfused_base = x_base.clone()
x_unfused = x_unfused_base[..., :hidden_size]
assert x_unfused.is_contiguous() != strided_input
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(
out_quant, x_unfused.contiguous(), quant_scale_t
)
torch.cuda.synchronize()
torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2)
opcheck(
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6),
)
else:
torch.ops._C.rms_norm_static_fp8_quant(
out_quant_fused, x, weight, quant_scale_t, 1e-6
)
torch.ops._C.rms_norm(out_norm, x, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t)
opcheck(
torch.ops._C.rms_norm_static_fp8_quant,
(out_quant_fused, x, weight, quant_scale_t, 1e-6),
)
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS) # @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) # @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment