Commit afd0da21 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.1' into v0.7.1-dev

parents 1a11f127 4f4d427a
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0230364128947258,
"1": 0.01979283057153225,
"2": 0.0241350457072258,
"3": 0.0308314748108387,
"4": 0.0430733822286129,
"5": 0.0370396226644516,
"6": 0.0306222103536129,
"7": 0.0357491634786129,
"8": 0.0358189195394516,
"9": 0.0443289652466774,
"10": 0.0433175228536129,
"11": 0.0416782945394516,
"12": 0.0366908498108387,
"13": 0.0432477705180645,
"14": 0.0410505048930645,
"15": 0.0457589291036129,
"16": 0.0418526791036129,
"17": 0.0432477705180645,
"18": 0.0469447560608387,
"19": 0.0514787957072258,
"20": 0.0541294664144516,
"21": 0.0587681382894516,
"22": 0.0625,
"23": 0.0585588738322258,
"24": 0.0600237175822258,
"25": 0.0588030144572258,
"26": 0.0531180277466774,
"27": 0.06396484375,
"28": 0.0603027381002903,
"29": 0.0582101047039032,
"30": 0.0625348836183548,
"31": 0.0585588738322258,
"32": 0.0582798570394516,
"33": 0.0575125589966774,
"34": 0.0590820349752903,
"35": 0.0614188089966774,
"36": 0.0631975457072258,
"37": 0.0615931935608387,
"38": 0.0601283498108387,
"39": 0.0571986623108387,
"40": 0.0670340433716774,
"41": 0.0523507259786129,
"42": 0.0547223798930645,
"43": 0.0631975457072258,
"44": 0.0663713738322258,
"45": 0.0603376142680645,
"46": 0.0652204304933548,
"47": 0.0734514519572258,
"48": 0.0693708211183548,
"49": 0.0725446492433548,
"50": 0.0627790242433548,
"51": 0.0691266804933548,
"52": 0.0688825398683548,
"53": 0.068429134786129,
"54": 0.0605119988322258,
"55": 0.0799386203289032,
"56": 0.0853097140789032,
"57": 0.0661969929933548,
"58": 0.0689871683716774,
"59": 0.0724051371216774,
"60": 0.0541643425822258,
"61": 0.0626743882894516,
"62": 0.0628487765789032,
"63": 0.0607212632894516,
"64": 0.0589076466858387,
"65": 0.0451660193502903,
"66": 0.0453055277466774,
"67": 0.0414341539144516,
"68": 0.0385044664144516,
"69": 0.0414341539144516,
"70": 0.0466308631002903,
"71": 0.0399693101644516,
"72": 0.0437011756002903,
"73": 0.0434221550822258,
"74": 0.0428989976644516,
"75": 0.0401785746216774,
"76": 0.0431082621216774,
"77": 0.0484444759786129,
"78": 0.0417829267680645,
"79": 0.0418178029358387
}
}
}
}
\ No newline at end of file
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0152239128947258,
"1": 0.0188860222697258,
"2": 0.0354178324341774,
"3": 0.0376674123108387,
"4": 0.0418526791036129,
"5": 0.0433175228536129,
"6": 0.0397600457072258,
"7": 0.0424455925822258,
"8": 0.0415387861430645,
"9": 0.0408412404358387,
"10": 0.0395856611430645,
"11": 0.0377371683716774,
"12": 0.0400739423930645,
"13": 0.040771484375,
"14": 0.0393415205180645,
"15": 0.0369001142680645,
"16": 0.03857421875,
"17": 0.0387486070394516,
"18": 0.0403180830180645,
"19": 0.0396205373108387,
"20": 0.0375627800822258,
"21": 0.0407366082072258,
"22": 0.0432477705180645,
"23": 0.0377022884786129,
"24": 0.0399693101644516,
"25": 0.0374581478536129,
"26": 0.0413295216858387,
"27": 0.0442243330180645,
"28": 0.0424804724752903,
"29": 0.0456891767680645,
"30": 0.0409109964966774,
"31": 0.0482352152466774
}
}
}
}
......@@ -6,8 +6,9 @@ import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, NewGELU,
QuickGELU, SiluAndMul)
GeluAndMul, MulAndSilu,
NewGELU, QuickGELU,
SiluAndMul)
from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol
......@@ -21,8 +22,9 @@ CUDA_DEVICES = [
]
@pytest.mark.parametrize("activation",
["silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize(
"activation",
["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
......@@ -40,9 +42,12 @@ def test_act_and_mul(
current_platform.seed_everything(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu":
if activation == "silu_and_mul":
layer = SiluAndMul()
fn = torch.ops._C.silu_and_mul
if activation == "mul_and_silu":
layer = MulAndSilu()
fn = torch.ops._C.mul_and_silu
elif activation == "gelu":
layer = GeluAndMul(approximate="none")
fn = torch.ops._C.gelu_and_mul
......@@ -55,8 +60,9 @@ def test_act_and_mul(
fn = torch.ops._C.fatrelu_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
# The SiLU, GELU and FatReLU implementations are equivalent to the native
# PyTorch implementations, so we can do exact comparison.
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
d = x.shape[-1] // 2
......
......@@ -31,9 +31,9 @@ NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 120, 256]
# This should be sync with get_supported_head_sizes() in
# vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
......@@ -182,7 +182,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the paged attention kernel.
output = torch.empty_like(query)
......
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.platforms import cpu, cuda, openvino, rocm
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
from vllm.platforms import current_platform
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"] if not current_platform() else ["ROCM_FLASH"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
......@@ -21,71 +31,76 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
cpu.CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.get_name() == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform",
rocm.RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.get_name() == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform",
openvino.OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
OpenVinoPlatform()), patch.dict('sys.modules',
{'openvino': Mock()}):
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.get_name() == "OPENVINO"
else:
with patch("vllm.attention.selector.current_platform",
cuda.CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name
if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
assert backend.get_name() == name
def test_flash_attn(monkeypatch):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# which_attn_to_use
# get_attn_backend
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported data type
backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = which_attn_to_use(16, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, None, 8, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported head size
backend = which_attn_to_use(17, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(17, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
assert backend.get_name() != STR_FLASH_ATTN_VAL
def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
"""Ignore the invalid env variable if it is set."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(16, torch.float16, None, 16, False)
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
backend = get_attn_backend(32, torch.float16, None, 16, False)
assert backend.get_name() == "FLASH_ATTN"
# when block size == 16, backend will fall back to XFORMERS
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() == "XFORMERS"
......@@ -92,8 +92,10 @@ def native_w8a8_block_fp8_matmul(A,
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
]
B_tiles = [[
B[j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles)
B[
j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K),
] for i in range(k_tiles)
] for j in range(n_tiles)]
C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
......@@ -157,9 +159,9 @@ def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE,
SEEDS))
@pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS))
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
torch.manual_seed(seed)
......@@ -174,9 +176,9 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
assert torch.allclose(scale, ref_scale)
@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES,
SEEDS))
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
......@@ -207,9 +209,10 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001
@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed",
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS,
BLOCK_SIZE, DTYPES, SEEDS))
@pytest.mark.parametrize(
"M,N,K,E,topk,block_size,dtype,seed",
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES,
SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
torch.manual_seed(seed)
......
......@@ -210,7 +210,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
tp_rank = 0
# Call the paged attention kernel.
......
......@@ -161,7 +161,7 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone()
# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
......@@ -259,8 +259,8 @@ def test_reshape_and_cache_flash(
del key_caches
del value_caches
k_scale = key.amax().item() / 256
v_scale = value.amax().item() / 256
k_scale = (key.amax() / 256.0).to(torch.float32)
v_scale = (value.amax() / 256.0).to(torch.float32)
# Clone the KV caches.
if kv_cache_dtype == "fp8":
......@@ -285,12 +285,12 @@ def test_reshape_and_cache_flash(
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache,
key_cache,
k_scale,
k_scale.item(),
kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache,
value_cache,
v_scale,
v_scale.item(),
kv_dtype=kv_cache_dtype)
# Run the reference implementation.
......
from typing import List, Optional, Tuple
import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (cascade_attention,
merge_attn_states)
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 192, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.float16, torch.bfloat16]
@pytest.mark.parametrize("num_tokens", [1, 39, 16912])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_merge_kernel(
num_tokens: int,
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
current_platform.seed_everything(0)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
# Prepare inputs.
prefix_output = torch.randn(num_tokens,
num_query_heads,
head_size,
dtype=dtype)
suffix_output = torch.randn(num_tokens,
num_query_heads,
head_size,
dtype=dtype)
prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
# Run the kernel.
output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype)
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse)
# Reference implementation.
max_lse = torch.maximum(prefix_lse, suffix_lse)
p_lse = torch.exp(prefix_lse - max_lse)
s_lse = torch.exp(suffix_lse - max_lse)
p_scale = p_lse / (p_lse + s_lse)
s_scale = s_lse / (p_lse + s_lse)
p_scale = p_scale.transpose(0, 1).unsqueeze(2)
s_scale = s_scale.transpose(0, 1).unsqueeze(2)
ref_output = p_scale * prefix_output + s_scale * suffix_output
ref_output = ref_output.to(dtype)
# Compare the results.
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
CASES = [
# Case 1. A general case.
([(129, 871), (18, 280), (37, 988), (1023, 2304), (1, 257)], 256),
# Case 2. Flash-decoding case.
([(1, 1023), (1, 879), (1, 778), (1, 1777)] * 100, 512),
]
@pytest.mark.parametrize("seq_lens_and_common_prefix", CASES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("soft_cap", [None, 50])
@pytest.mark.parametrize("num_blocks", [2048])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
current_platform.seed_everything(0)
window_size = (-1, -1)
scale = head_size**-0.5
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
seq_lens, common_prefix_len = seq_lens_and_common_prefix
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
total_num_query_tokens = sum(query_lens)
query = torch.randn(total_num_query_tokens,
num_query_heads,
head_size,
dtype=dtype)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
assert common_prefix_len > 0
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
# Make sure the first `num_common_kv_blocks` blocks are the same.
block_tables[:, :num_common_kv_blocks] = \
block_tables[0, :num_common_kv_blocks]
# Run the regular attention.
ref_output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
seqused_k=kv_lens_tensor,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)
# Run cascade attention.
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
dtype=torch.int32)
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
suffix_kv_lens = kv_lens_tensor - common_prefix_len
output = torch.empty_like(query)
cascade_attention(
output=output,
query=query,
key_cache=key_cache,
value_cache=value_cache,
cu_query_lens=cu_query_lens,
max_query_len=max_query_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
max_kv_len=max_kv_len,
softmax_scale=scale,
alibi_slopes=None,
sliding_window=window_size,
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
fa_version=fa_version,
)
# Compare the results.
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
......@@ -2,7 +2,7 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Optional, Type
from typing import Type, Optional
import pytest
import torch
......@@ -10,6 +10,9 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import cdiv
from .utils import baseline_scaled_mm, to_fp8, to_int8
MNK_FACTORS = [
(1, 256, 128),
......@@ -37,20 +40,15 @@ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1)
PER_TOKEN_GROUP_SHAPE = (1, -1)
PER_OUT_CH_GROUP_SHAPE = (-1, 1)
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)
......@@ -66,14 +64,22 @@ def baseline_scaled_mm(a: torch.Tensor,
if bias is not None:
output = output + bias
return output
def group_scale_helper(shape, group_shape):
return [shape[i] if s < 0 else s for i, s in enumerate(group_shape)]
def scale_shape(shape, group_shape):
assert len(shape) == len(group_shape)
group_shape = group_scale_helper(shape, group_shape)
return tuple(
cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
def cutlass_fp8_gemm_helper(m: int,
n: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
......@@ -82,13 +88,17 @@ def cutlass_fp8_gemm_helper(m: int,
a = to_fp8(torch.randn((m, k), device=device))
b = to_fp8(torch.randn((n, k), device=device).t())
m_a_scales = m if per_token_act_quant else 1
n_b_scales = n if per_out_channel_weight_quant else 1
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
# make scales M-major for blockwise quant, doesn't affect 1D scales
scale_a = scale_a.t().contiguous().t()
# make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b = scale_b.t().contiguous().t()
scale_a = (torch.randn((m_a_scales, 1), device=device,
dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
dtype=torch.float32))
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
else:
......@@ -106,8 +116,8 @@ def cutlass_fp8_gemm_helper(m: int,
def cutlass_int8_gemm_helper(m: int,
n: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
......@@ -116,13 +126,11 @@ def cutlass_int8_gemm_helper(m: int,
a = to_int8(torch.randn((m, k), device=device) * 5)
b = to_int8(torch.randn((n, k), device=device).t() * 5)
m_a_scales = m if per_token_act_quant else 1
n_b_scales = n if per_out_channel_weight_quant else 1
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn((m_a_scales, 1), device=device,
dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
dtype=torch.float32))
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
......@@ -139,85 +147,139 @@ def cutlass_int8_gemm_helper(m: int,
# @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
# per_out_ch: bool, use_bias: bool):
# cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
# def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
# b_scale_group_shape, use_bias: bool):
# cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
# use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
[((1, 128), (128, 128))])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_ch", [True])
@pytest.mark.parametrize("out_dtype", [ torch.float16]) #torch.bfloat16,
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype],
use_bias: bool):
cutlass_int8_gemm_helper(512,
512,
512,
per_act_token,
per_out_ch,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
# b_scale_group_shape,
# out_dtype: Type[torch.dtype],
# use_bias: bool):
# cutlass_fp8_gemm_helper(512,
# 512,
# 512,
# per_act_token,
# per_out_ch,
# a_scale_group_shape,
# b_scale_group_shape,
# use_bias,
# out_dtype=out_dtype)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
# [((1, 128), (128, 128))])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [False])
# @pytest.mark.skipif(not current_platform.has_device_capability(90),
# reason="FP8 blockwise is not supported on this GPU type.")
# def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
# b_scale_group_shape,
# out_dtype: Type[torch.dtype],
# use_bias: bool):
# cutlass_fp8_gemm_helper(512,
# 512,
# 512,
# a_scale_group_shape,
# b_scale_group_shape,
# use_bias,
# out_dtype=out_dtype)
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
# use_bias: bool, device: str):
# cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
# torch.bfloat16, device)
# cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
# b_scale_group_shape, use_bias, torch.bfloat16,
# device)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# use_bias: bool, device: str):
# cutlass_int8_gemm_helper(512,
# 512,
# 512,
# per_act_token,
# per_out_ch,
# use_bias,
# out_dtype=torch.bfloat16,
# device=device)
# def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias: bool, device: str):
cutlass_int8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=torch.bfloat16,
device=device)
# For the following two tests:
......@@ -225,28 +287,32 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
# def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
# use_bias: bool):
# for nk in range(32, 128, 32):
# for m in range(1, 128):
# cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
# use_bias)
# cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
# b_scale_group_shape, use_bias)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
# use_bias: bool):
# for nk in range(32, 128, 32):
# for m in range(1, 128):
# cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
# use_bias)
# def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
use_bias: bool):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
b_scale_group_shape, use_bias)
# @pytest.mark.parametrize("m", [32, 64, 128])
......@@ -304,38 +370,39 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# @pytest.mark.parametrize("n", [16, 32, 64])
# @pytest.mark.parametrize("k", [64, 128, 256])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.skip
# def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
# out_dtype: torch.dtype):
# # Currently, the test is failing because folding azp into
# # 16-bit bias loses too much precision
# scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
# scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
# aq_i8 = rand_int8((m, k))
# bq_i8 = rand_int8((n, k)).t()
# aq_i32 = aq_i8.to(dtype=torch.int32)
# bq_i32 = bq_i8.to(dtype=torch.int32)
# aq_f32 = aq_i8.to(dtype=torch.float32)
# bq_f32 = bq_i8.to(dtype=torch.float32)
# b_dq = scale_b * bq_f32
# azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
# azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
# azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
# a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
# torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
# baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("azp_per_token", [True, False])
# def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
# use_bias: bool, azp_per_token: bool):
# m_azp = m if azp_per_token else 1
# scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
# scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
# aq_i8 = rand_int8((m, k))
# aq_i32 = aq_i8.to(dtype=torch.int32)
# aq_f32 = aq_i8.to(dtype=torch.float32)
# bq_i8 = rand_int8((n, k)).t()
# bq_i32 = bq_i8.to(dtype=torch.int32)
# bq_f32 = bq_i8.to(dtype=torch.float32)
# b_dq = scale_b * bq_f32
# azp_a = torch.rand(
# (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
# azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
# azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
# a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
# torch.testing.assert_close(a_dq,
# scale_a * aq_f32 - azp_a,
# rtol=1e-4,
# atol=1e-3)
# if use_bias:
# bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
# else:
# bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
# J = torch.ones((1, k), device="cuda", dtype=torch.float32)
# azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
# assert azp_bias.shape == (1, n)
# assert azp_bias[0, :].shape == (n, )
# baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
# (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
......
......@@ -2,16 +2,19 @@
Run `pytest tests/kernels/test_semi_structured.py`.
"""
from typing import Optional, Tuple, Type
from typing import Tuple, Type
import pytest
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform
from .utils import baseline_scaled_mm, to_fp8, to_int8
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
......@@ -20,20 +23,6 @@ capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.bfloat16)
......@@ -90,22 +79,8 @@ def make_rand_sparse_tensors(
return b_compressed, e, a, b
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = (scale_a * (scale_b * (torch.mm(
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
if bias is not None:
output = output + bias
return output
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.")
reason="Sparse CUTLASS is not supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
......@@ -132,3 +107,108 @@ def test_cutlass_sparse_subset():
out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
MNK_FACTORS = [
(1, 256, 128),
(1, 16384, 1024),
(1, 24576, 512),
(16, 256, 512),
(16, 16384, 128),
(16, 24576, 4096),
(32, 8192, 4096),
(32, 16384, 4096),
(33, 1024, 1024),
(33, 8192, 128),
(64, 2048, 512),
(64, 16384, 1024),
(100, 8192, 512),
(128, 32768, 4096),
(256, 4096, 4096),
(512, 256, 1024),
(512, 8192, 4096),
(512, 16384, 128),
(512, 24576, 128),
]
# Test working with a subset of A and B for sparse matmul
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=dtype)
baseline = F.linear(a, b.T)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2)
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
......@@ -13,8 +13,7 @@ import pytest
import torch
from tests.kernels.utils import *
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType)
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
......@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
max_dec_seq_len: int
max_enc_seq_len: int
num_blocks: int
attn_type: AttentionType
class TestResources(NamedTuple):
......@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
'''
scale: float
attn_backend: AttentionBackend
attn: Attention
kv_cache: torch.Tensor
......@@ -129,26 +128,33 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
'''
scale = float(1.0 / (test_pt.head_size**0.5))
attn_backend = make_backend(test_pt.backend_name)
attn = Attention(
test_pt.num_heads,
test_pt.head_size,
scale=scale,
prefix=f"{test_pt.attn_type}",
attn_type=test_pt.attn_type,
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(
scale, attn_backend, attn,
scale, attn,
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
return TestResources(scale, attn_backend, attn, kv_cache)
if test_pt.attn_type in (AttentionType.DECODER,
AttentionType.ENCODER_DECODER):
kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.num_heads,
test_pt.head_size,
test_pt.block_size,
device=CUDA_DEVICE,
backend=test_pt.backend_name)
else:
kv_cache = torch.tensor([])
attn.kv_cache = [kv_cache]
return TestResources(scale, attn, kv_cache)
def _encoder_attn_setup(
......@@ -193,6 +199,7 @@ def _encoder_attn_setup(
_,
max_q_seq_len,
_,
_,
) = test_pt
scale = test_rsrcs.scale
......@@ -301,6 +308,7 @@ def _decoder_attn_setup(
max_q_seq_len,
_,
_,
_,
) = test_pt
scale = test_rsrcs.scale
......@@ -488,6 +496,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
max_decoder_seq_len,
max_encoder_seq_len,
_,
_,
) = test_pt
scale = test_rsrcs.scale
......@@ -622,7 +631,6 @@ def _run_encoder_attention_test(
& attn_metadata
'''
assert attn_metadata.num_decode_tokens == 0
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None
with set_forward_context(attn_metadata, vllm_config):
......@@ -635,14 +643,11 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
return attn.forward(
reshaped_query, packed_qkv.key, packed_qkv.value,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device), attn_metadata)
def _run_decoder_self_attention_test(
......@@ -675,7 +680,6 @@ def _run_decoder_self_attention_test(
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
attn_type = AttentionType.DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
......@@ -690,12 +694,8 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
packed_qkv.key,
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
kv_cache, attn_metadata)
def _run_encoder_decoder_cross_attention_test(
......@@ -742,7 +742,6 @@ def _run_encoder_decoder_cross_attention_test(
'''
assert decoder_test_params.packed_qkvo.packed_qkv is not None
attn_type = AttentionType.ENCODER_DECODER
attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache
if cross_test_params is None:
......@@ -762,12 +761,8 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query,
key,
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
return attn.forward(reshaped_query, key, value, kv_cache,
attn_metadata)
@pytest.fixture(autouse=True)
......@@ -839,7 +834,7 @@ def test_encoder_only(
# is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
max_enc_seq_len, 4096, AttentionType.ENCODER)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
......@@ -855,7 +850,7 @@ def test_encoder_only(
# Shared prefill metadata structure
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
True,
None,
decoder_test_params=None,
......@@ -961,20 +956,29 @@ def test_e2e_enc_dec_attn(
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096)
enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.ENCODER)
enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096,
AttentionType.ENCODER_DECODER)
dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.DECODER)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt)
enc_test_rsrcs = _make_test_resources(enc_test_pt)
enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
dec_test_rsrcs = _make_test_resources(dec_test_pt)
# Construct encoder attention test params (only used
# during prefill)
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)
# Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention
......@@ -987,7 +991,7 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params,
decphase_dec_test_params,
cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs)
) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)
# Construct encoder/decoder cross-attention prefill-phase
# & decode-phase test params, including key/value tensors,
......@@ -1000,14 +1004,14 @@ def test_e2e_enc_dec_attn(
dec_qkv,
enc_test_params,
prephase_dec_test_params,
test_pt,
test_rsrcs,
enc_dec_test_pt,
enc_dec_test_rsrcs,
block_base_addr=cross_block_base_addr)
# Shared prefill metadata structure
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
True,
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
decoder_test_params=prephase_dec_test_params,
......@@ -1017,10 +1021,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder attention
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
enc_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=enc_test_pt,
vllm_config=vllm_config)
# - Is encoder attention result correct?
......@@ -1030,10 +1034,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs,
dec_test_rsrcs,
prephase_dec_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=dec_test_pt,
vllm_config=vllm_config)
# - Is prefill decoder self-attention correct?
......@@ -1044,11 +1048,11 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs,
enc_dec_test_rsrcs,
prephase_dec_test_params,
prephase_cross_test_params,
prephase_attn_metadata,
test_pt=test_pt,
test_pt=enc_dec_test_pt,
vllm_config=vllm_config)
# - Is prefill encoder/decoder cross-attention correct?
......@@ -1059,7 +1063,7 @@ def test_e2e_enc_dec_attn(
# DECODE: build decode-phase attention metadata
decphase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend,
attn_backend,
False,
dec_qkv.q_seq_lens,
decoder_test_params=decphase_dec_test_params,
......@@ -1070,10 +1074,10 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs,
dec_test_rsrcs,
decphase_dec_test_params,
decphase_attn_metadata,
test_pt=test_pt,
test_pt=dec_test_pt,
vllm_config=vllm_config)
# - Is decode-phase decoder self-attention correct?
......@@ -1084,11 +1088,11 @@ def test_e2e_enc_dec_attn(
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs,
enc_dec_test_rsrcs,
decphase_dec_test_params,
None,
decphase_attn_metadata,
test_pt=test_pt,
test_pt=enc_dec_test_pt,
vllm_config=vllm_config)
# - Is decode-phase encoder/decoder cross-attention correct?
......
......@@ -5,11 +5,14 @@ import torch
from vllm.platforms import current_platform
if current_platform():
import flash_attn
else:
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
......@@ -84,6 +87,7 @@ if not current_platform():
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
......@@ -95,8 +99,13 @@ if not current_platform():
soft_cap: Optional[float],
num_blocks: int,
sliding_window: Optional[int],
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
current_platform.seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
......@@ -135,6 +144,7 @@ if not current_platform():
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
fa_version=fa_version,
)
output = output if not use_out else out
output = output.squeeze(1)
......@@ -150,9 +160,8 @@ if not current_platform():
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
......@@ -164,6 +173,7 @@ if not current_platform():
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_varlen_with_paged_kv(
use_out: bool,
......@@ -175,8 +185,12 @@ def test_varlen_with_paged_kv(
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
......@@ -206,6 +220,7 @@ def test_varlen_with_paged_kv(
cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
......@@ -230,6 +245,7 @@ def test_varlen_with_paged_kv(
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
# fa_version=fa_version,
)
else:
output = flash_attn_varlen_func(
......@@ -238,7 +254,7 @@ def test_varlen_with_paged_kv(
v=value_cache,
out=out,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
seqused_k=kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
......@@ -246,7 +262,9 @@ def test_varlen_with_paged_kv(
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
)
output = output if not use_out else out
ref_output = ref_paged_attn(
......
"""
Test:
* Tests for MultiHeadAttention layer
"""
from unittest.mock import patch
import pytest
import torch
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend, _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
"""
Test the attention selector between different platform and device.
"""
torch.set_default_dtype(torch.float16)
if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
else:
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
def ref_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
"""
Native implementation of scaled dot product attention without mask:
- query, key, value: [batch_size, seq_len, num_heads, head_size]
- attn_mask: [batch_size, seq_len, seq_len]
"""
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.matmul(attn_weights, value).transpose(1, 2)
return out
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
CUDA_DEVICES = ["cuda"]
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
batch_size: int,
seq_len: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: str,
):
current_platform.seed_everything(0)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
q = torch.randn(batch_size, seq_len, num_heads * head_size)
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads)
output = attn(q, k, v)
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
q = q.reshape(batch_size, seq_len, num_heads, head_size)
k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
if num_queries_per_kv > 1:
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
ref_output = ref_attention(
q,
k,
v,
scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size)
torch.testing.assert_close(output, ref_output)
......@@ -14,8 +14,12 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
......@@ -46,8 +50,102 @@ def test_fused_moe(
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
torch.testing.assert_close(iterative_output,
torch_output,
atol=2e-2,
rtol=0)
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype, group_size: int, has_zp: bool,
weight_bits: int):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
if weight_bits == 4:
pack_factor = 2
quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
elif weight_bits == 8:
pack_factor = 1
quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
w1_ref = w1.clone()
w2_ref = w2.clone()
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
device="cuda",
dtype=torch.uint8)
w2_qweight = torch.empty((e, k, n // pack_factor),
device="cuda",
dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size),
device="cuda",
dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size),
device="cuda",
dtype=dtype)
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
device="cuda",
dtype=torch.uint8)
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
device="cuda",
dtype=torch.uint8)
for i in range(e * 2):
expert_id = i % e
if i // e == 0:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
else:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
weight, qweight, scales, qzeros = quantize_weights(
w[expert_id].T, quant_type, group_size, has_zp, False)
weight = weight.T
qweight = qweight.T.contiguous().to(torch.uint8)
scales = scales.T
if has_zp:
qzeros = qzeros.T.contiguous().to(torch.uint8)
if weight_bits == 4:
qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
if has_zp:
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
w_ref[expert_id] = weight
w_qweight[expert_id] = qweight
w_scales[expert_id] = scales
if has_zp:
w_qzeros[expert_id] = qzeros
triton_output = fused_moe(a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
......
......@@ -140,6 +140,7 @@ def test_contexted_kv_attention(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
......@@ -155,6 +156,8 @@ def test_contexted_kv_attention(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
start_time = time.time()
......@@ -170,6 +173,8 @@ def test_contexted_kv_attention(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
end_time = time.time()
......@@ -369,6 +374,7 @@ def test_contexted_kv_attention_alibi(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
......@@ -384,6 +390,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
start_time = time.time()
......@@ -399,6 +407,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len,
b_ctx_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
end_time = time.time()
......
import pytest
import torch
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
def cdiv(a, b):
return (a + b - 1) // b
@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L # This represents the number of tokens already in the sequence
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1),
device="cuda")
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
# Call the original implementation.
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
# Page size can be larger than 1.
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1)
......@@ -39,6 +39,23 @@ def get_8bit_types():
return types
# This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
max_tokens, num_logprobs):
dtype = "bfloat16"
with vllm_runner(model_path, dtype=dtype) as vllm_model:
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
num_logprobs)
@pytest.mark.parametrize("M", [1, 33, 64, 512])
@pytest.mark.parametrize("N", [256, 971, 20486])
@pytest.mark.parametrize("K", [128, 496, 1024])
......
......@@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4)
)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype)
output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap)
output = wrapper.run(query, key_value_cache)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
......@@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
......@@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
num_kv_heads,
head_size,
block_size,
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap,
)
output = wrapper.forward(
output = wrapper.run(
query,
key_value_cache,
logits_soft_cap=soft_cap,
)
ref_output = ref_paged_attn(query=query,
......@@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
......@@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
......@@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
num_kv_heads,
head_size,
block_size,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap,
)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache.squeeze(1),
......@@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
del query
del block_tables
# verify prefill fp8
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
......@@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv(
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=use_tensor_cores)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype,
q_data_type=dtype)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
wrapper.plan(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap)
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
......
......@@ -5,7 +5,7 @@ import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Union)
Type, Union)
import pytest
import torch
......@@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
......@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(
def make_test_metadata(
attn_backend: AttentionBackend,
attn_backend: _Backend,
is_prompt: bool,
seq_lens: Optional[List[int]],
decoder_test_params: Optional[PhaseTestParameters],
......@@ -815,7 +816,7 @@ def make_test_metadata(
Arguments:
* attn_backend: Backend for sourcing attention kernels
* attn_backend_name: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params;
......@@ -882,6 +883,8 @@ def make_test_metadata(
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap
attn_backend_obj = make_backend(attn_backend.name)
if is_prompt:
# Prefill-phase scenario
......@@ -902,11 +905,11 @@ def make_test_metadata(
context_lens,
encoder_seq_lens,
device=device)
return attn_backend.make_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
......@@ -952,10 +955,11 @@ def make_test_metadata(
encoder_seq_lens,
device=device)
return attn_backend.make_metadata(
return attn_backend_obj.make_metadata(
num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
......@@ -1096,3 +1100,56 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
kwargs,
test_utils=test_utils,
raise_exception=raise_exception) if cond else {}
# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# We treat N-dimensional group scaling as extended numpy-style broadcasting
# in numpy simply stretches dimensions with an extent of 1 to match the
# the target shape by repeating the data along that dimension (broadcasting)
# , we extend these semantics to say if the extent of a dimension in the
# source shape is not 1 and does not match the target shape we repeat each
# element along that dimension src_shape[dim] // target_shape[dim] times
# example if we have:
# a = [[1, 2], and target_shape = (2, 4)
# [3, 4]]
# then we would expand a to:
# a = [[1, 1, 2, 2],
# [3, 3, 4, 4]]
# NOTE this function this function does not explicitly broadcast dimensions
# with an extent of 1, since this can be done implicitly by pytorch
def group_broadcast(t, shape):
for i, s in enumerate(shape):
if t.shape[i] != s and t.shape[i] != 1:
assert s % t.shape[i] == 0
t = t.unsqueeze(i + 1)\
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
.flatten(i, i + 1)
return t
scale_a = group_broadcast(scale_a, a.shape)
scale_b = group_broadcast(scale_b, b.shape)
output = torch.mm((scale_a * a.to(dtype=torch.float32)),
(scale_b * b.to(dtype=torch.float32))).to(out_dtype)
if bias is not None:
output = output + bias
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment