Commit 469e903b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-dev

parents 389ebcf7 25f560a6
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -17,6 +17,8 @@ if not current_platform.is_rocm(): ...@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm.attention.backends.xformers import _make_alibi_bias
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability. # This will change depending on the compute capability.
# - 512 as a buffer # - 512 as a buffer
...@@ -25,6 +27,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 ...@@ -25,6 +27,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Reduce NUM_BLOCKS when it happens. # Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512 PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [ DTYPES = [
torch.half, torch.bfloat16, torch.float torch.half, torch.bfloat16, torch.float
...@@ -85,8 +88,8 @@ def ref_single_query_cached_kv_attention( ...@@ -85,8 +88,8 @@ def ref_single_query_cached_kv_attention(
block_table = block_tables_lst[i] block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i]) seq_len = int(seq_lens_lst[i])
keys_lst: List[torch.Tensor] = [] keys_lst: list[torch.Tensor] = []
values_lst: List[torch.Tensor] = [] values_lst: list[torch.Tensor] = []
for j in range(seq_len): for j in range(seq_len):
block_number = int(block_table[j // block_size]) block_number = int(block_table[j // block_size])
block_offset = j % block_size block_offset = j % block_size
...@@ -133,7 +136,7 @@ def test_paged_attention( ...@@ -133,7 +136,7 @@ def test_paged_attention(
kv_cache_factory, kv_cache_factory,
version: str, version: str,
num_seqs: int, num_seqs: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
use_alibi: bool, use_alibi: bool,
block_size: int, block_size: int,
...@@ -146,6 +149,8 @@ def test_paged_attention( ...@@ -146,6 +149,8 @@ def test_paged_attention(
or (version == "rocm" and head_size not in (64, 128))): or (version == "rocm" and head_size not in (64, 128))):
pytest.skip() pytest.skip()
global PARTITION_SIZE
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
...@@ -166,7 +171,7 @@ def test_paged_attention( ...@@ -166,7 +171,7 @@ def test_paged_attention(
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables_lst: List[List[int]] = [] block_tables_lst: list[list[int]] = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
random.randint(0, NUM_BLOCKS - 1) random.randint(0, NUM_BLOCKS - 1)
...@@ -214,6 +219,9 @@ def test_paged_attention( ...@@ -214,6 +219,9 @@ def test_paged_attention(
and block_size == BLOCK_SIZES[0])) and block_size == BLOCK_SIZES[0]))
elif version in ("v2", "rocm"): elif version in ("v2", "rocm"):
if current_platform.is_rocm() and version == "rocm":
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0 assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape num_seqs, num_heads, head_size = output.shape
...@@ -334,25 +342,31 @@ def test_paged_attention( ...@@ -334,25 +342,31 @@ def test_paged_attention(
def ref_multi_query_kv_attention( def ref_multi_query_kv_attention(
cu_seq_lens: List[int], cu_seq_lens: list[int],
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
scale: float, scale: float,
alibi_bias: Optional[list[torch.Tensor]],
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1 num_seqs = len(cu_seq_lens) - 1
ref_outputs: List[torch.Tensor] = [] ref_outputs: list[torch.Tensor] = []
if alibi_bias:
assert len(alibi_bias) == num_seqs
for i in range(num_seqs): for i in range(num_seqs):
start_idx = cu_seq_lens[i] start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1] end_idx = cu_seq_lens[i + 1]
seq_len = end_idx - start_idx seq_len = end_idx - start_idx
# Create attention mask. # Create attention mask. ALiBi already includes a tril causal mask.
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), if alibi_bias:
diagonal=1) attn_mask = alibi_bias[i]
attn_mask = attn_mask * torch.finfo(dtype).min else:
attn_mask = attn_mask.to(dtype=dtype) attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype)
ref_output = ref_masked_attention( ref_output = ref_masked_attention(
query[start_idx:end_idx], query[start_idx:end_idx],
...@@ -366,7 +380,6 @@ def ref_multi_query_kv_attention( ...@@ -366,7 +380,6 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0) return torch.cat(ref_outputs, dim=0)
# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
...@@ -378,11 +391,12 @@ def ref_multi_query_kv_attention( ...@@ -378,11 +391,12 @@ def ref_multi_query_kv_attention(
@torch.inference_mode() @torch.inference_mode()
def test_multi_query_kv_attention( def test_multi_query_kv_attention(
num_seqs: int, num_seqs: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: str, device: str,
use_alibi: bool = False,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
...@@ -408,16 +422,40 @@ def test_multi_query_kv_attention( ...@@ -408,16 +422,40 @@ def test_multi_query_kv_attention(
# Handle MQA and GQA # Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) alibi_bias = None
output = xops.memory_efficient_attention_forward( if use_alibi:
query.unsqueeze(0), alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
key.unsqueeze(0), attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
value.unsqueeze(0), seq_lens)
attn_bias=attn_bias, output = torch.empty_like(query)
p=0.0, start = 0
scale=scale, # Dynamic sequence length not supported with custom attn_bias.
) for i, seq_len in enumerate(seq_lens):
output = output.squeeze(0) end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
output[start:end].copy_(out.view_as(query[start:end]))
start += seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias = [
b.materialize(b.shape, device=device).squeeze() for b in attn_bias
]
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=scale,
)
output = output.squeeze(0)
cu_seq_lens = [0] cu_seq_lens = [0]
for seq_len in seq_lens: for seq_len in seq_lens:
...@@ -428,8 +466,37 @@ def test_multi_query_kv_attention( ...@@ -428,8 +466,37 @@ def test_multi_query_kv_attention(
key, key,
value, value,
scale, scale,
alibi_bias,
dtype, dtype,
) )
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
num_seqs: int,
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
return test_multi_query_kv_attention(
num_seqs,
num_heads,
head_size,
dtype,
seed,
device,
use_alibi=True,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch from unittest.mock import patch
import pytest import pytest
import torch import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -23,86 +22,117 @@ def clear_cache(): ...@@ -23,86 +22,117 @@ def clear_cache():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"] if not current_platform() else ["ROCM_FLASH"]) "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"] if not current_platform.is_rocm() else ["ROCM_FLASH"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"]) @pytest.mark.parametrize("use_v1", [True, False])
def test_env(name: str, device: str, monkeypatch): @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(
name: str,
use_v1: bool,
device: str,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that the attention selector can be set via environment variable. """Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend. Note that we do not test FlashAttn because it is the default backend.
""" """
override_backend_env_variable(monkeypatch, name) with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
if device == "cpu": m.setenv(STR_BACKEND_ENV_VAR, name)
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, 16, if device == "cpu":
False)
assert backend.get_name() == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.get_name() == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform",
OpenVinoPlatform()), patch.dict('sys.modules',
{'openvino': Mock()}):
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.get_name() == "OPENVINO"
else:
if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform",
CudaPlatform()): CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, backend = get_attn_backend(16, torch.float16, torch.float16,
16, False) 16, False)
assert backend.get_name() == name assert backend.get_name() == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform",
def test_flash_attn(monkeypatch): RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == EXPECTED
else:
if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float16,
torch.float16, 16, False)
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == EXPECTED
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to # TODO: When testing for v1, pipe in `use_v1` as an argument to
# get_attn_backend # get_attn_backend
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
# Unsupported CUDA arch # Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)): monkeypatch.setattr(torch.cuda, "get_device_capability", lambda:
(7, 5))
backend = get_attn_backend(16, torch.float16, None, 16, False) backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported data type # Reset the monkeypatch for subsequent tests
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False) monkeypatch.undo()
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type # Unsupported data type
backend = get_attn_backend(16, torch.float16, "fp8", 16, False) backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported block size # Unsupported kv cache data type
backend = get_attn_backend(16, torch.float16, None, 8, False) backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# flash-attn is not installed # Unsupported block size
with patch.dict('sys.modules', {'vllm_flash_attn': None}): backend = get_attn_backend(16, torch.float16, None, 8, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# flash-attn is not installed
import sys
original_module = sys.modules.get('vllm_flash_attn')
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
backend = get_attn_backend(16, torch.float16, None, 16, False) backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported head size # Restore the original module if it existed
backend = get_attn_backend(17, torch.float16, None, 16, False) if original_module is not None:
assert backend.get_name() != STR_FLASH_ATTN_VAL monkeypatch.setitem(sys.modules, 'vllm_flash_attn',
original_module)
else:
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
# Unsupported head size
backend = get_attn_backend(17, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention @pytest.mark.parametrize("use_v1", [True, False])
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True) def test_invalid_env(use_v1: bool, monkeypatch: pytest.MonkeyPatch):
assert backend.get_name() != STR_FLASH_ATTN_VAL
with monkeypatch.context() as m, patch(
"vllm.attention.selector.current_platform", CudaPlatform()):
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
def test_invalid_env(monkeypatch): # Test with head size 32
"""Ignore the invalid env variable if it is set."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
backend = get_attn_backend(32, torch.float16, None, 16, False) backend = get_attn_backend(32, torch.float16, None, 16, False)
assert backend.get_name() == "FLASH_ATTN" EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
assert backend.get_name() == EXPECTED
# when block size == 16, backend will fall back to XFORMERS # when block size == 16, backend will fall back to XFORMERS
backend = get_attn_backend(16, torch.float16, None, 16, False) # this behavior is not yet supported on V1.
assert backend.get_name() == "XFORMERS" if use_v1:
# TODO: support fallback on V1!
# https://github.com/vllm-project/vllm/issues/14524
pass
else:
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() == "XFORMERS"
...@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq( ...@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
num_bits=num_bits, num_bits=num_bits,
) )
torch_output = torch_moe( torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
a, score, topk, None)
w_ref1.transpose(1, 2),
w_ref2.transpose(1, 2),
score,
topk,
)
assert compute_max_diff(marlin_output, torch_output) < 4e-2 assert compute_max_diff(marlin_output, torch_output) < 4e-2
......
...@@ -30,8 +30,8 @@ M_moe = [1, 7, 83, 512, 2048] ...@@ -30,8 +30,8 @@ M_moe = [1, 7, 83, 512, 2048]
N_moe = [4608] # [128, 4608, 13824] N_moe = [4608] # [128, 4608, 13824]
K_moe = [7168] # [256, 7168, 13824] K_moe = [7168] # [256, 7168, 13824]
BLOCK_SIZE = [[128, 128]] BLOCK_SIZE = [[128, 128]]
E = [256] # [8, 24, 128, 256] E = [8, 24] # [8, 24, 128, 256]
TOP_KS = [1] # [1, 2, 6] TOP_KS = [2] # [1, 2, 6]
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0] SEEDS = [0]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention( ...@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
block_table = block_tables_lst[i] block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i]) seq_len = int(seq_lens_lst[i])
keys_lst: List[torch.Tensor] = [] keys_lst: list[torch.Tensor] = []
values_lst: List[torch.Tensor] = [] values_lst: list[torch.Tensor] = []
for j in range(seq_len): for j in range(seq_len):
block_number = int(block_table[j // block_size]) block_number = int(block_table[j // block_size])
block_offset = j % block_size block_offset = j % block_size
...@@ -162,7 +162,7 @@ def test_paged_attention( ...@@ -162,7 +162,7 @@ def test_paged_attention(
kv_cache_factory, kv_cache_factory,
version: str, version: str,
num_seqs: int, num_seqs: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
use_alibi: bool, use_alibi: bool,
block_size: int, block_size: int,
...@@ -331,7 +331,7 @@ def test_paged_attention( ...@@ -331,7 +331,7 @@ def test_paged_attention(
def ref_multi_query_kv_attention( def ref_multi_query_kv_attention(
cu_seq_lens: List[int], cu_seq_lens: list[int],
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
...@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention( ...@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
@torch.inference_mode() @torch.inference_mode()
def test_varlen_blocksparse_attention_prefill( def test_varlen_blocksparse_attention_prefill(
num_seqs: int, num_seqs: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
blocksparse_local_blocks: int, blocksparse_local_blocks: int,
blocksparse_vert_stride: int, blocksparse_vert_stride: int,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import random import random
from typing import List, Tuple
import pytest import pytest
import torch import torch
...@@ -9,7 +8,6 @@ import torch ...@@ -9,7 +8,6 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import align_to_256bytes
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
...@@ -75,7 +73,7 @@ def test_copy_blocks( ...@@ -75,7 +73,7 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings) src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping: List[Tuple[int, int]] = [] block_mapping: list[tuple[int, int]] = []
for i in range(num_mappings): for i in range(num_mappings):
src = src_blocks[i] src = src_blocks[i]
dst1 = dst_blocks[2 * i] dst1 = dst_blocks[2 * i]
...@@ -160,19 +158,20 @@ def test_reshape_and_cache( ...@@ -160,19 +158,20 @@ def test_reshape_and_cache(
device) device)
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
# Clone the KV caches. # Clone the KV caches.
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache) ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item())
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache) ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item())
else: else:
cloned_key_cache = key_cache.clone() cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache.clone()
# Using default kv_scale
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache, opcheck(torch.ops._C_cache_ops.reshape_and_cache,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
...@@ -183,9 +182,9 @@ def test_reshape_and_cache( ...@@ -183,9 +182,9 @@ def test_reshape_and_cache(
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache, key_cache) ops.convert_fp8(result_key_cache, key_cache, k_scale.item())
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache, value_cache) ops.convert_fp8(result_value_cache, value_cache, v_scale.item())
# Run the reference implementation. # Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
...@@ -269,15 +268,16 @@ def test_reshape_and_cache_flash( ...@@ -269,15 +268,16 @@ def test_reshape_and_cache_flash(
del key_caches del key_caches
del value_caches del value_caches
k_scale = (key.amax() / 256.0).to(torch.float32) k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 256.0).to(torch.float32) v_scale = (value.amax() / 64.0).to(torch.float32)
# Clone the KV caches. # Clone the KV caches.
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype) ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache, v_scale, ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
kv_cache_dtype) kv_cache_dtype)
else: else:
cloned_key_cache = key_cache.clone() cloned_key_cache = key_cache.clone()
...@@ -341,7 +341,7 @@ def test_reshape_and_cache_flash( ...@@ -341,7 +341,7 @@ def test_reshape_and_cache_flash(
@torch.inference_mode() @torch.inference_mode()
def test_swap_blocks( def test_swap_blocks(
kv_cache_factory, kv_cache_factory,
direction: Tuple[str, str], direction: tuple[str, str],
num_mappings: int, num_mappings: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
...@@ -452,22 +452,13 @@ def _create_mla_cache( ...@@ -452,22 +452,13 @@ def _create_mla_cache(
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
align_cache: bool,
) -> torch.Tensor: ) -> torch.Tensor:
cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
return torch.zeros(num_blocks,
if align_cache: block_size,
alloc_entry_size = align_to_256bytes(entry_size, cache_dtype) entry_size,
alloc_shape = (num_blocks, block_size, alloc_entry_size) dtype=cache_dtype,
cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device) device=device)
cache = cache_full[..., :entry_size]
else:
cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=cache_dtype,
device=device)
return cache
def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str): def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
...@@ -490,7 +481,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str): ...@@ -490,7 +481,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False])
@torch.inference_mode() @torch.inference_mode()
def test_concat_and_cache_mla( def test_concat_and_cache_mla(
kv_lora_rank: int, kv_lora_rank: int,
...@@ -502,7 +492,6 @@ def test_concat_and_cache_mla( ...@@ -502,7 +492,6 @@ def test_concat_and_cache_mla(
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
align_cache: bool,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
...@@ -522,7 +511,7 @@ def test_concat_and_cache_mla( ...@@ -522,7 +511,7 @@ def test_concat_and_cache_mla(
scale = torch.tensor(0.1, dtype=torch.float32, device=device) scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens): for i in range(num_tokens):
...@@ -578,7 +567,6 @@ def test_concat_and_cache_mla( ...@@ -578,7 +567,6 @@ def test_concat_and_cache_mla(
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False, True])
@torch.inference_mode() @torch.inference_mode()
def test_copy_blocks_mla( def test_copy_blocks_mla(
kv_lora_rank: int, kv_lora_rank: int,
...@@ -590,7 +578,6 @@ def test_copy_blocks_mla( ...@@ -590,7 +578,6 @@ def test_copy_blocks_mla(
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
align_cache: bool,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
...@@ -600,7 +587,7 @@ def test_copy_blocks_mla( ...@@ -600,7 +587,7 @@ def test_copy_blocks_mla(
kv_caches = [] kv_caches = []
for _ in range(num_layers): for _ in range(num_layers):
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype) _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
kv_caches.append(kv_cache) kv_caches.append(kv_cache)
...@@ -644,7 +631,6 @@ def test_copy_blocks_mla( ...@@ -644,7 +631,6 @@ def test_copy_blocks_mla(
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False, True])
@torch.inference_mode() @torch.inference_mode()
def test_swap_blocks_mla( def test_swap_blocks_mla(
kv_lora_rank: int, kv_lora_rank: int,
...@@ -655,7 +641,6 @@ def test_swap_blocks_mla( ...@@ -655,7 +641,6 @@ def test_swap_blocks_mla(
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
align_cache: bool,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
...@@ -663,9 +648,9 @@ def test_swap_blocks_mla( ...@@ -663,9 +648,9 @@ def test_swap_blocks_mla(
entry_size = kv_lora_rank + qk_rope_head_dim entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype) _fill_mla_cache(src_cache, kv_cache_dtype)
_fill_mla_cache(dst_cache, kv_cache_dtype) _fill_mla_cache(dst_cache, kv_cache_dtype)
...@@ -685,8 +670,6 @@ def test_swap_blocks_mla( ...@@ -685,8 +670,6 @@ def test_swap_blocks_mla(
torch.ops._C_cache_ops.swap_blocks, torch.ops._C_cache_ops.swap_blocks,
(src_cache, dst_cache, block_mapping_tensor), (src_cache, dst_cache, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS, test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(kv_lora_rank == KV_LORA_RANKS[0]
and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
) )
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor) ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
...@@ -697,3 +680,75 @@ def test_swap_blocks_mla( ...@@ -697,3 +680,75 @@ def test_swap_blocks_mla(
dst_cache[dst].cpu(), dst_cache[dst].cpu(),
msg=f"Block {src} from src should have been swapped to block " msg=f"Block {src} from src should have been swapped to block "
f"{dst} in dst_cache.") f"{dst} in dst_cache.")
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)
for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm
dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)
expected_batches = []
for b in range(batch_size):
s = seq_len_tensor[b]
if s == 0:
continue
tot = tot_blocks_tensor[b]
blocks = block_table[b, :tot].tolist()
gathered_rows = []
for i in range(tot - 1):
gathered_rows.append(src_cache[blocks[i]])
remaining = s - (tot - 1) * block_size
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)
opcheck(
torch.ops._C_cache_ops.gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16] ...@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
@torch.inference_mode() @torch.inference_mode()
def test_merge_kernel( def test_merge_kernel(
num_tokens: int, num_tokens: int,
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
): ):
...@@ -85,8 +85,8 @@ CASES = [ ...@@ -85,8 +85,8 @@ CASES = [
@pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode() @torch.inference_mode()
def test_cascade( def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int], seq_lens_and_common_prefix: tuple[list[tuple[int, int]], int],
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_cutlass.py`. Run `pytest tests/kernels/test_cutlass.py`.
""" """
from typing import Type, Optional
import pytest import pytest
import torch import torch
...@@ -82,7 +81,7 @@ def cutlass_fp8_gemm_helper(m: int, ...@@ -82,7 +81,7 @@ def cutlass_fp8_gemm_helper(m: int,
a_scale_group_shape: tuple, a_scale_group_shape: tuple,
b_scale_group_shape: tuple, b_scale_group_shape: tuple,
use_bias: bool, use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16, out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"): device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization # Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization. # and per-output channel weight quantization.
...@@ -120,7 +119,7 @@ def cutlass_int8_gemm_helper(m: int, ...@@ -120,7 +119,7 @@ def cutlass_int8_gemm_helper(m: int,
a_scale_group_shape: tuple, a_scale_group_shape: tuple,
b_scale_group_shape: tuple, b_scale_group_shape: tuple,
use_bias: bool, use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16, out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"): device: str = "cuda"):
# Test for a cutlass kernel with per-token activation quantization # Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization. # and per-output channel weight quantization.
...@@ -198,7 +197,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape, ...@@ -198,7 +197,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape, b_scale_group_shape,
out_dtype: Type[torch.dtype], out_dtype: type[torch.dtype],
use_bias: bool): use_bias: bool):
cutlass_int8_gemm_helper(512, cutlass_int8_gemm_helper(512,
512, 512,
...@@ -208,26 +207,25 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, ...@@ -208,26 +207,25 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
use_bias, use_bias,
out_dtype=out_dtype) out_dtype=out_dtype)
@pytest.mark.parametrize("a_scale_group_shape",
# @pytest.mark.parametrize("a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) @pytest.mark.parametrize("b_scale_group_shape",
# @pytest.mark.parametrize("b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.skipif(not current_platform.has_device_capability(89),
# @pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.")
# reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
# def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape, b_scale_group_shape,
# b_scale_group_shape, out_dtype: type[torch.dtype],
# out_dtype: Type[torch.dtype], use_bias: bool):
# use_bias: bool): cutlass_fp8_gemm_helper(512,
# cutlass_fp8_gemm_helper(512, 512,
# 512, 512,
# 512, a_scale_group_shape,
# a_scale_group_shape, b_scale_group_shape,
# b_scale_group_shape, use_bias,
# use_bias, out_dtype=out_dtype)
# out_dtype=out_dtype)
# @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape", # @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
...@@ -238,7 +236,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, ...@@ -238,7 +236,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
# reason="FP8 blockwise is not supported on this GPU type.") # reason="FP8 blockwise is not supported on this GPU type.")
# def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape, # def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
# b_scale_group_shape, # b_scale_group_shape,
# out_dtype: Type[torch.dtype], # out_dtype: type[torch.dtype],
# use_bias: bool): # use_bias: bool):
# cutlass_fp8_gemm_helper(512, # cutlass_fp8_gemm_helper(512,
# 512, # 512,
...@@ -271,15 +269,15 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, ...@@ -271,15 +269,15 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
# @pytest.mark.parametrize("use_bias", [True, False]) # @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
# def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape, # def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias: bool, device: str): # use_bias: bool, device: str):
cutlass_int8_gemm_helper(512, # cutlass_int8_gemm_helper(512,
512, # 512,
512, # 512,
a_scale_group_shape, # a_scale_group_shape,
b_scale_group_shape, # b_scale_group_shape,
use_bias, # use_bias,
out_dtype=torch.bfloat16, # out_dtype=torch.bfloat16,
device=device) # device=device)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_semi_structured.py`. Run `pytest tests/kernels/test_semi_structured.py`.
""" """
from typing import Tuple, Type
import pytest import pytest
import torch import torch
...@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor, ...@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
def make_rand_sparse_tensors( def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int dtype: torch.dtype, m: int, n: int, k: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t() b = torch.randn((n, k), device='cuda').t()
...@@ -167,7 +166,7 @@ MNK_FACTORS = [ ...@@ -167,7 +166,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m, n, k", MNK_FACTORS) @pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype], def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
use_bias: bool): use_bias: bool):
# Create tensors # Create tensors
......
...@@ -22,6 +22,16 @@ from vllm.config import VllmConfig, set_current_vllm_config ...@@ -22,6 +22,16 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Encoder-decoder is only supported on V0, so set
VLLM_USE_V1=0 for all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
# List of support backends for encoder/decoder models # List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
HEAD_SIZES = [64, 256] HEAD_SIZES = [64, 256]
...@@ -243,7 +253,7 @@ def _decoder_attn_setup( ...@@ -243,7 +253,7 @@ def _decoder_attn_setup(
test_pt: TestPoint, test_pt: TestPoint,
test_rsrcs: TestResources, test_rsrcs: TestResources,
block_base_addr: int = 0, block_base_addr: int = 0,
) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: ) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
''' '''
Set up test vectors & data structures for self-attention test. Set up test vectors & data structures for self-attention test.
...@@ -421,7 +431,7 @@ def _enc_dec_cross_attn_setup_reuses_query( ...@@ -421,7 +431,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
test_pt: TestPoint, test_pt: TestPoint,
test_rsrcs: TestResources, test_rsrcs: TestResources,
block_base_addr: int = 0, block_base_addr: int = 0,
) -> Tuple[PhaseTestParameters, PhaseTestParameters]: ) -> tuple[PhaseTestParameters, PhaseTestParameters]:
''' '''
Set up test vectors & data structures for cross-attention test. Set up test vectors & data structures for cross-attention test.
...@@ -644,11 +654,7 @@ def _run_encoder_attention_test( ...@@ -644,11 +654,7 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view( reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward( return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
reshaped_query, packed_qkv.key, packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device), attn_metadata)
def _run_decoder_self_attention_test( def _run_decoder_self_attention_test(
...@@ -682,7 +688,6 @@ def _run_decoder_self_attention_test( ...@@ -682,7 +688,6 @@ def _run_decoder_self_attention_test(
& attn_metadata & attn_metadata
''' '''
attn = test_rsrcs.attn attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata, vllm_config): with set_forward_context(attn_metadata, vllm_config):
...@@ -695,8 +700,7 @@ def _run_decoder_self_attention_test( ...@@ -695,8 +700,7 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view( reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value, return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
kv_cache, attn_metadata)
def _run_encoder_decoder_cross_attention_test( def _run_encoder_decoder_cross_attention_test(
...@@ -744,7 +748,6 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -744,7 +748,6 @@ def _run_encoder_decoder_cross_attention_test(
assert decoder_test_params.packed_qkvo.packed_qkv is not None assert decoder_test_params.packed_qkvo.packed_qkv is not None
attn = test_rsrcs.attn attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
if cross_test_params is None: if cross_test_params is None:
key = None key = None
value = None value = None
...@@ -762,8 +765,7 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -762,8 +765,7 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query, key, value, kv_cache, return attn.forward(reshaped_query, key, value)
attn_metadata)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -8,8 +8,8 @@ import torch ...@@ -8,8 +8,8 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform(): if current_platform.is_rocm():
import flash_attn from flash_attn import flash_attn_varlen_func
else: else:
from vllm.vllm_flash_attn import (fa_version_unsupported_reason, from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func, flash_attn_varlen_func,
...@@ -20,6 +20,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)] ...@@ -20,6 +20,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation. # one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check # one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048] NUM_BLOCKS = [32768, 2048]
...@@ -29,8 +30,8 @@ def ref_paged_attn( ...@@ -29,8 +30,8 @@ def ref_paged_attn(
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
query_lens: List[int], query_lens: list[int],
kv_lens: List[int], kv_lens: list[int],
block_tables: torch.Tensor, block_tables: torch.Tensor,
scale: float, scale: float,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
...@@ -40,7 +41,7 @@ def ref_paged_attn( ...@@ -40,7 +41,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy() block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape _, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = [] outputs: list[torch.Tensor] = []
start_idx = 0 start_idx = 0
for i in range(num_seqs): for i in range(num_seqs):
query_len = query_lens[i] query_len = query_lens[i]
...@@ -79,91 +80,124 @@ def ref_paged_attn( ...@@ -79,91 +80,124 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0) return torch.cat(outputs, dim=0)
if not current_platform(): @pytest.mark.skipif(current_platform.is_rocm(),
@pytest.mark.parametrize("use_out", [True, False]) reason="flash_attn_with_paged_kv is not supported on ROCm.")
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("sliding_window", [None, 256])
@torch.inference_mode() @pytest.mark.parametrize("fa_version", [2, 3])
def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("q_dtype", QDTYPES)
use_out: bool, @torch.inference_mode()
kv_lens: List[int], def test_flash_attn_with_paged_kv(
num_heads: Tuple[int, int], use_out: bool,
head_size: int, kv_lens: list[int],
dtype: torch.dtype, num_heads: tuple[int, int],
block_size: int, head_size: int,
soft_cap: Optional[float], dtype: torch.dtype,
num_blocks: int, block_size: int,
sliding_window: Optional[int], soft_cap: Optional[float],
fa_version: int, num_blocks: int,
) -> None: sliding_window: Optional[int],
torch.set_default_device("cuda") fa_version: int,
if not is_fa_version_supported(fa_version): q_dtype: Optional[torch.dtype],
pytest.skip(f"Flash attention version {fa_version} not supported due " ) -> None:
f"to: \"{fa_version_unsupported_reason(fa_version)}\"") torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
current_platform.seed_everything(0) pytest.skip(f"Flash attention version {fa_version} not supported due "
num_seqs = len(kv_lens) f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
num_query_heads = num_heads[0] if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
num_kv_heads = num_heads[1] pytest.skip("Flash attention with quantized inputs is only "
assert num_query_heads % num_kv_heads == 0 "supported on version 3 with bfloat16 base type")
max_kv_len = max(kv_lens)
scale = head_size**-0.5 current_platform.seed_everything(0)
window_size = ((sliding_window - 1, 0) if sliding_window is not None else num_seqs = len(kv_lens)
(-1, -1)) num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) assert num_query_heads % num_kv_heads == 0
key_cache = torch.randn(num_blocks, max_kv_len = max(kv_lens)
block_size, scale = head_size**-0.5
num_kv_heads, window_size = ((sliding_window - 1, 0) if sliding_window is not None else
head_size, (-1, -1))
dtype=dtype)
value_cache = torch.randn_like(key_cache) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) key_cache = torch.randn(num_blocks,
block_size,
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size num_kv_heads,
block_tables = torch.randint(0, head_size,
num_blocks, dtype=dtype)
(num_seqs, max_num_blocks_per_seq), value_cache = torch.randn_like(key_cache)
dtype=torch.int32) kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
q = query.unsqueeze(1) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
out = torch.empty_like(q) if use_out else None block_tables = torch.randint(0,
output = flash_attn_with_kvcache( num_blocks,
q=q, (num_seqs, max_num_blocks_per_seq),
k_cache=key_cache, dtype=torch.int32)
v_cache=value_cache,
out=out, q = query.unsqueeze(1)
softmax_scale=scale, out = torch.empty_like(q) if use_out else None
causal=True,
block_table=block_tables, maybe_quantized_query = q
cache_seqlens=kv_lens_tensor, maybe_quantized_key_cache = key_cache
softcap=soft_cap if soft_cap is not None else 0, maybe_quantized_value_cache = value_cache
window_size=window_size, q_descale = None
fa_version=fa_version, k_descale = None
) v_descale = None
output = output if not use_out else out if q_dtype is not None:
output = output.squeeze(1) # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query = query.to(q_dtype)
ref_output = ref_paged_attn(query=query, maybe_quantized_key_cache = key_cache.to(q_dtype)
key_cache=key_cache, maybe_quantized_value_cache = value_cache.to(q_dtype)
value_cache=value_cache,
query_lens=[1] * num_seqs, scale_shape = (num_seqs, num_kv_heads)
kv_lens=kv_lens, q_descale = torch.ones(scale_shape, dtype=torch.float32)
block_tables=block_tables, k_descale = torch.ones(scale_shape, dtype=torch.float32)
scale=scale, v_descale = torch.ones(scale_shape, dtype=torch.float32)
soft_cap=soft_cap,
sliding_window=sliding_window) output = flash_attn_with_kvcache(
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ q=maybe_quantized_query,
f"{torch.max(torch.abs(output - ref_output))}" k_cache=maybe_quantized_key_cache,
v_cache=maybe_quantized_value_cache,
out=out,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
output = output if not use_out else out
output = output.squeeze(1)
atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.skipif(current_platform.is_rocm(),
reason="varlen_with_paged_kv is not supported on ROCm.")
@pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens", @pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18), [[(1, 1328), (5, 18),
...@@ -176,11 +210,12 @@ if not current_platform(): ...@@ -176,11 +210,12 @@ if not current_platform():
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3]) @pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode() @torch.inference_mode()
def test_varlen_with_paged_kv( def test_varlen_with_paged_kv(
use_out: bool, use_out: bool,
seq_lens: List[Tuple[int, int]], seq_lens: list[tuple[int, int]],
num_heads: Tuple[int, int], num_heads: tuple[int, int],
head_size: int, head_size: int,
sliding_window: Optional[int], sliding_window: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
...@@ -188,11 +223,15 @@ def test_varlen_with_paged_kv( ...@@ -188,11 +223,15 @@ def test_varlen_with_paged_kv(
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
fa_version: int, fa_version: int,
q_dtype: Optional[torch.dtype],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version): if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due " pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"") f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
...@@ -219,9 +258,6 @@ def test_varlen_with_paged_kv( ...@@ -219,9 +258,6 @@ def test_varlen_with_paged_kv(
cu_query_lens = torch.tensor([0] + query_lens, cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0, dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32) dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32) kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
...@@ -231,42 +267,43 @@ def test_varlen_with_paged_kv( ...@@ -231,42 +267,43 @@ def test_varlen_with_paged_kv(
dtype=torch.int32) dtype=torch.int32)
out = torch.empty_like(query) if use_out else None out = torch.empty_like(query) if use_out else None
if current_platform():
output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
out=out,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
# fa_version=fa_version,
)
else:
output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
out=out,
cu_seqlens_q=cu_query_lens,
seqused_k=kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
)
maybe_quantized_query = query
maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache
q_descale = None
k_descale = None
v_descale = None
if q_dtype is not None:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query = query.to(q_dtype)
maybe_quantized_key_cache = key_cache.to(q_dtype)
maybe_quantized_value_cache = value_cache.to(q_dtype)
scale_shape = (num_seqs, num_kv_heads)
q_descale = torch.ones(scale_shape, dtype=torch.float32)
k_descale = torch.ones(scale_shape, dtype=torch.float32)
v_descale = torch.ones(scale_shape, dtype=torch.float32)
output = flash_attn_varlen_func(
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
v=maybe_quantized_value_cache,
out=out,
cu_seqlens_q=cu_query_lens,
seqused_k=kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
output = output if not use_out else out output = output if not use_out else out
ref_output = ref_paged_attn( ref_output = ref_paged_attn(
...@@ -280,5 +317,8 @@ def test_varlen_with_paged_kv( ...@@ -280,5 +317,8 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window, sliding_window=sliding_window,
soft_cap=soft_cap, soft_cap=soft_cap,
) )
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
\ No newline at end of file
# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py
# SPDX-License-Identifier: Apache-2.0
import math
import random
import pytest
import torch
import triton
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-5
FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported"
@pytest.mark.skipif(not is_flashmla_supported()[0],
reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen):
# TODO: parametrize using pytest
dtype = torch.bfloat16
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}")
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
s_q)
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
q = torch.randn(b, s_q, h_q, d)
block_table = torch.arange(b * max_seqlen_pad // block_size,
dtype=torch.int32).view(
b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv,
d)[i, cache_seqlens[i].item():] = float("nan")
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv)
def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
)
def scaled_dot_product_attention(query, key, value, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k,
dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
ref_O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
is_causal=causal,
)
out[i] = ref_O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla, fast_flush=False)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union from typing import Optional, Union
import pytest import pytest
import torch import torch
...@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: ...@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def ref_rms_norm(rms_norm_layer: RMSNorm, def ref_rms_norm(rms_norm_layer: RMSNorm,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor]) \ residual: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is not None: if residual is not None:
residual = residual.clone() residual = residual.clone()
out, residual = rms_norm_layer.forward_native(x, residual) out, residual = rms_norm_layer.forward_native(x, residual)
...@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, ...@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \ scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if scale_ub is not None: if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn assert quant_dtype == torch.float8_e4m3fn
...@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm, ...@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \ scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
residual, scale_ub) residual, scale_ub)
...@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor, ...@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \ scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if residual is not None: if residual is not None:
residual = residual.clone() residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
...@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor, ...@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \ scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
scale_ub) scale_ub)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from pathlib import Path from pathlib import Path
from typing import List
import pytest import pytest
import os import os
...@@ -10,23 +9,37 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize ...@@ -10,23 +9,37 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import models_path_prefix from ..utils import models_path_prefix
# GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") # GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
# GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
GGUF_SAMPLE = os.path.join(models_path_prefix, "Isotr0py/test-gguf-sample") GGUF_SAMPLE = os.path.join(models_path_prefix, "Isotr0py/test-gguf-sample")
GGUF_SAMPLE_MOE = os.path.join(models_path_prefix, "SzymonOzog/test-gguf-moe-sample")
def get_gguf_sample_tensors( def get_gguf_sample_tensors(
hidden_size: int, hidden_size: int,
quant_type: GGMLQuantizationType) -> List[ReaderTensor]: quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename sample_file = Path(sample_dir) / filename
return GGUFReader(sample_file).tensors return GGUFReader(sample_file).tensors
DTYPES = [torch.half] def get_gguf_MoE_tensors(
hidden_size: int,
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE_MOE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
return GGUFReader(sample_file).tensors
DTYPES = [torch.half, torch.bfloat16, torch.float32]
# Hidden_size for testing, must match the sample file in HF repo, # Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently. # we have `hidden_size = 256, 1024` for test in HF repo currently.
HIDDEN_SIZES = [256, 1024] HIDDEN_SIZES = [256, 1024]
...@@ -56,7 +69,7 @@ QUANT_TYPES = [ ...@@ -56,7 +69,7 @@ QUANT_TYPES = [
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", [torch.half])
@pytest.mark.parametrize("quant_type", QUANT_TYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode() @torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype, def test_dequantize(hidden_size: int, dtype: torch.dtype,
...@@ -126,7 +139,64 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, ...@@ -126,7 +139,64 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
ref_output = x @ weight.T ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda") qweight = torch.tensor(tensor.data, device="cuda")
output = ops.ggml_mul_mat_a8(qweight, x, quant_type, output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
qweight.shape[0]).to(dtype) atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
# test matrix has inputs centered around 0 and lower precision from
# bfloat16 tends to accumulate and can greatly inflate rtol
# since outputs are also very close to 0
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
torch.testing.assert_close(output,
ref_output,
atol=atols[dtype],
rtol=rtols[dtype])
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", [512])
@pytest.mark.parametrize("top_k", [4, 8])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize(
"quant_type",
[
# k-quants
GGMLQuantizationType.Q2_K,
GGMLQuantizationType.Q3_K,
GGMLQuantizationType.Q4_K,
GGMLQuantizationType.Q5_K,
GGMLQuantizationType.Q6_K,
# standard quants
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
])
@torch.inference_mode()
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType, top_k: int):
current_platform.seed_everything(0)
H, E = 1024, 256
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda")
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
w13 = tensors[0]
w2 = tensors[1]
w13_dequant = torch.tensor(dequantize(w13.data, quant_type),
device="cuda").to(dtype)
w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
device="cuda").to(dtype)
act = SiluAndMul()
output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
torch.tensor(w2.data,
device="cuda"), topk_weights,
topk_ids, quant_type, quant_type, act)
ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
topk_ids).reshape(output.shape)
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
...@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`. ...@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`.
import math import math
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import List, Optional, Tuple from typing import Optional
import pytest import pytest
import torch import torch
...@@ -45,7 +45,7 @@ MNK_SHAPES = [ ...@@ -45,7 +45,7 @@ MNK_SHAPES = [
(1024, 8192, 4096), (1024, 8192, 4096),
] ]
GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1] GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
@dataclass @dataclass
...@@ -75,7 +75,7 @@ class Tensors: ...@@ -75,7 +75,7 @@ class Tensors:
# Ch Scales Type, Tok Scales Type) # Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point # NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type # None "Output Type" means the output type is the same as the act type
TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype], TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool] Optional[torch.dtype], bool]
TEST_TYPES = [ TEST_TYPES = [
# GPTQ style # GPTQ style
...@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): ...@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype)) return zps if zps is None else -1 * s * (zps.to(s.dtype))
def group_size_valid(shape: Tuple[int, int, int], def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool: group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0 return group_size is None or group_size == -1 or group_size % shape[2] == 0
...@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype, ...@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_machete, w_s, w_zp return w_ref, w_q_machete, w_s, w_zp
def create_test_tensors(shape: Tuple[int, int, int], def create_test_tensors(shape: tuple[int, int, int],
types: TypeConfig, types: TypeConfig,
group_size: Optional[int], group_size: Optional[int],
subset_stride_factor: Optional[int] = None) -> Tensors: subset_stride_factor: Optional[int] = None) -> Tensors:
...@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig, ...@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig,
@pytest.mark.parametrize("types", TEST_TYPES) @pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig): def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: List[Optional[int]] = [] group_sizes: list[Optional[int]] = []
if types.group_scale_type is None: if types.group_scale_type is None:
group_sizes = [None] group_sizes = [None]
else: else:
...@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig): ...@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
ids=lambda x: "x".join(str(v) for v in x)) ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES) @pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig): def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: List[Optional[int]] = [] group_sizes: list[Optional[int]] = []
if types.group_scale_type is None: if types.group_scale_type is None:
group_sizes = [None] group_sizes = [None]
else: else:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import unittest import unittest
from typing import Tuple
import pytest import pytest
import torch import torch
...@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables ...@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables
def test_mixer2_gated_norm_multi_gpu( def test_mixer2_gated_norm_multi_gpu(
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
hidden_size_n_groups: Tuple[int, int], hidden_size_n_groups: tuple[int, int],
dtype: torch.dtype, dtype: torch.dtype,
device: str = 'cuda', device: str = 'cuda',
): ):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, Tuple
import pytest import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch, ...@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
# given a tuple of lengths for each example in the batch # given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg, # e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc # 4 examples from second eg, etc
def get_continuous_batch(example_lens: Tuple[int, ...]): def get_continuous_batch(example_lens: tuple[int, ...]):
indices = [] indices = []
for i, x in enumerate(example_lens): for i, x in enumerate(example_lens):
...@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, ...@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# hold state during the cutting process so we know if an # hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle # example has been exhausted and needs to cycle
last_taken: Dict = {} # map: eg -> pointer to last taken sample last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: Dict = {} # map: eg -> boolean indicating example is exhausted exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None states = None
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
......
...@@ -3,8 +3,11 @@ ...@@ -3,8 +3,11 @@
Run `pytest tests/kernels/test_moe.py`. Run `pytest tests/kernels/test_moe.py`.
""" """
import pytest import pytest
import torch import torch
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import MixtralConfig from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
...@@ -26,6 +29,7 @@ from vllm.platforms import current_platform ...@@ -26,6 +29,7 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
NUM_EXPERTS = [8, 64] NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6] TOP_KS = [2, 6]
...@@ -34,24 +38,64 @@ TOP_KS = [2, 6] ...@@ -34,24 +38,64 @@ TOP_KS = [2, 6]
@pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
def test_fused_moe( def test_fused_moe(
m: int, m: int,
n: int, n: int,
k: int, k: int,
e: int, e: int,
topk: int, topk: int,
ep_size: int,
dtype: torch.dtype, dtype: torch.dtype,
padding: bool,
): ):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk) if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device="cuda",
dtype=torch.int32)
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
# Pad the weight if moe padding is enabled
if padding:
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
triton_output = fused_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
torch.testing.assert_close(iterative_output, torch.testing.assert_close(iterative_output,
torch_output, torch_output,
atol=2e-2, atol=2e-2,
...@@ -63,13 +107,14 @@ def test_fused_moe( ...@@ -63,13 +107,14 @@ def test_fused_moe(
@pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8]) @pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype, group_size: int, has_zp: bool, ep_size: int, dtype: torch.dtype, group_size: int,
weight_bits: int): has_zp: bool, weight_bits: int):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
...@@ -130,6 +175,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -130,6 +175,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
if has_zp: if has_zp:
w_qzeros[expert_id] = qzeros w_qzeros[expert_id] = qzeros
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device="cuda",
dtype=torch.int32)
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1_ref = w1_ref[e_ids]
w2_ref = w2_ref[e_ids]
w1_qweight = w1_qweight[e_ids]
w2_qweight = w2_qweight[e_ids]
w1_scales = w1_scales[e_ids]
w2_scales = w2_scales[e_ids]
w1_qzeros = w1_qzeros[e_ids]
w2_qzeros = w2_qzeros[e_ids]
else:
e_map = None
triton_output = fused_moe(a, triton_output = fused_moe(a,
w1_qweight, w1_qweight,
w2_qweight, w2_qweight,
...@@ -138,19 +202,22 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ...@@ -138,19 +202,22 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize=False, renormalize=False,
use_int4_w4a16=weight_bits == 4, use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8, use_int8_w8a16=weight_bits == 8,
global_num_experts=e,
expert_map=e_map,
w1_scale=w1_scales, w1_scale=w1_scales,
w2_scale=w2_scales, w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None, w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None, w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size]) block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk) torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.parametrize("dtype", @pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@torch.inference_mode() @torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype): def test_mixtral_moe(dtype: torch.dtype, padding: bool):
"""Make sure our Mixtral MoE implementation agrees with the one from """Make sure our Mixtral MoE implementation agrees with the one from
huggingface.""" huggingface."""
...@@ -164,6 +231,7 @@ def test_mixtral_moe(dtype: torch.dtype): ...@@ -164,6 +231,7 @@ def test_mixtral_moe(dtype: torch.dtype):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
params_dtype=dtype, params_dtype=dtype,
tp_size=1, tp_size=1,
dp_size=1,
).cuda() ).cuda()
# Load the weights # Load the weights
...@@ -179,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype): ...@@ -179,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
# vLLM uses 1D query [num_tokens, hidden_dim] # vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs = hf_inputs.flatten(0, 1) vllm_inputs = hf_inputs.flatten(0, 1)
# Pad the weight if moe padding is enabled
if padding:
vllm_moe.experts.w13_weight = Parameter(F.pad(
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
requires_grad=False)
torch.cuda.empty_cache()
# Run forward passes for both MoE blocks # Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs) hf_states, _ = hf_moe.forward(hf_inputs)
vllm_states = vllm_moe.forward(vllm_inputs) vllm_states = vllm_moe.forward(vllm_inputs)
......
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloatArray = [
0.,
0.5,
1.,
1.5,
2.,
3.,
4.,
6.,
]
def e2m1_to_fp32(int4_value):
signBit = (int4_value & 0x8)
int4_absValue = int4_value & 0x7
float_result = kE2M1ToFloatArray[int4_absValue]
if (signBit):
float_result = -float_result
return float_result
def break_fp4_bytes(a, dtype):
assert (a.dtype == torch.uint8)
m, n = a.shape
a = a.flatten()
# Get upper 4 bits
highHalfByte = (a & 0xF0) >> 4
# Get lower 4 bits
lowHalfByte = a & 0x0F
fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
return out
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
sf_m, sf_k = a_sf_swizzled.shape
m_tiles = (m + 128 - 1) // 128
f = block_size * 4
k_tiles = (k + f - 1) // f
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
return out[0:m, 0:k]
def dequantize_to_dtype(tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out
def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
m, n, dtype, block_size, device):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert (m_k == n_k)
a_in_dtype = dequantize_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
return torch.matmul(a_in_dtype, b_in_dtype.t())
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_nvfp4_gemm(
dtype: torch.dtype,
shape: tuple[int, int, int],
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)
m, n, packed_k = shape
k = packed_k * 2
block_size = 16
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
alpha = 1. / (a_global_scale * b_global_scale)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, a_global_scale,
b_global_scale, m, n, dtype, block_size,
device)
out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, alpha, dtype)
torch.testing.assert_close(out,
expected_out.to(dtype=dtype),
atol=1e-1,
rtol=1e-1)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from itertools import accumulate, product from itertools import accumulate, product
from typing import Callable, Dict, List, Optional from typing import Callable, Optional
import pytest import pytest
import torch import torch
...@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
torch.set_default_device(device) torch.set_default_device(device)
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
scaling_factors: List[int] = [1, 2, 4] scaling_factors: list[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"rope_type": "linear", "rope_type": "linear",
"factor": tuple(scaling_factors) "factor": tuple(scaling_factors)
...@@ -234,7 +234,7 @@ def test_rope_module_cache(): ...@@ -234,7 +234,7 @@ def test_rope_module_cache():
}) })
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES) ROPE_SCALINGS, DTYPES)
rope_setting_id_map: Dict[str, int] = {} rope_setting_id_map: dict[str, int] = {}
for setting in product(*settings): for setting in product(*settings):
head_size, rotary_dim, max_position, base, \ head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting is_neox_stype, rope_scaling, dtype = setting
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment