Commit afd0da21 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.1' into v0.7.1-dev

parents 1a11f127 4f4d427a
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0230364128947258,
"1": 0.01979283057153225,
"2": 0.0241350457072258,
"3": 0.0308314748108387,
"4": 0.0430733822286129,
"5": 0.0370396226644516,
"6": 0.0306222103536129,
"7": 0.0357491634786129,
"8": 0.0358189195394516,
"9": 0.0443289652466774,
"10": 0.0433175228536129,
"11": 0.0416782945394516,
"12": 0.0366908498108387,
"13": 0.0432477705180645,
"14": 0.0410505048930645,
"15": 0.0457589291036129,
"16": 0.0418526791036129,
"17": 0.0432477705180645,
"18": 0.0469447560608387,
"19": 0.0514787957072258,
"20": 0.0541294664144516,
"21": 0.0587681382894516,
"22": 0.0625,
"23": 0.0585588738322258,
"24": 0.0600237175822258,
"25": 0.0588030144572258,
"26": 0.0531180277466774,
"27": 0.06396484375,
"28": 0.0603027381002903,
"29": 0.0582101047039032,
"30": 0.0625348836183548,
"31": 0.0585588738322258,
"32": 0.0582798570394516,
"33": 0.0575125589966774,
"34": 0.0590820349752903,
"35": 0.0614188089966774,
"36": 0.0631975457072258,
"37": 0.0615931935608387,
"38": 0.0601283498108387,
"39": 0.0571986623108387,
"40": 0.0670340433716774,
"41": 0.0523507259786129,
"42": 0.0547223798930645,
"43": 0.0631975457072258,
"44": 0.0663713738322258,
"45": 0.0603376142680645,
"46": 0.0652204304933548,
"47": 0.0734514519572258,
"48": 0.0693708211183548,
"49": 0.0725446492433548,
"50": 0.0627790242433548,
"51": 0.0691266804933548,
"52": 0.0688825398683548,
"53": 0.068429134786129,
"54": 0.0605119988322258,
"55": 0.0799386203289032,
"56": 0.0853097140789032,
"57": 0.0661969929933548,
"58": 0.0689871683716774,
"59": 0.0724051371216774,
"60": 0.0541643425822258,
"61": 0.0626743882894516,
"62": 0.0628487765789032,
"63": 0.0607212632894516,
"64": 0.0589076466858387,
"65": 0.0451660193502903,
"66": 0.0453055277466774,
"67": 0.0414341539144516,
"68": 0.0385044664144516,
"69": 0.0414341539144516,
"70": 0.0466308631002903,
"71": 0.0399693101644516,
"72": 0.0437011756002903,
"73": 0.0434221550822258,
"74": 0.0428989976644516,
"75": 0.0401785746216774,
"76": 0.0431082621216774,
"77": 0.0484444759786129,
"78": 0.0417829267680645,
"79": 0.0418178029358387
}
}
}
}
\ No newline at end of file
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0152239128947258,
"1": 0.0188860222697258,
"2": 0.0354178324341774,
"3": 0.0376674123108387,
"4": 0.0418526791036129,
"5": 0.0433175228536129,
"6": 0.0397600457072258,
"7": 0.0424455925822258,
"8": 0.0415387861430645,
"9": 0.0408412404358387,
"10": 0.0395856611430645,
"11": 0.0377371683716774,
"12": 0.0400739423930645,
"13": 0.040771484375,
"14": 0.0393415205180645,
"15": 0.0369001142680645,
"16": 0.03857421875,
"17": 0.0387486070394516,
"18": 0.0403180830180645,
"19": 0.0396205373108387,
"20": 0.0375627800822258,
"21": 0.0407366082072258,
"22": 0.0432477705180645,
"23": 0.0377022884786129,
"24": 0.0399693101644516,
"25": 0.0374581478536129,
"26": 0.0413295216858387,
"27": 0.0442243330180645,
"28": 0.0424804724752903,
"29": 0.0456891767680645,
"30": 0.0409109964966774,
"31": 0.0482352152466774
}
}
}
}
...@@ -6,8 +6,9 @@ import torch ...@@ -6,8 +6,9 @@ import torch
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, NewGELU, GeluAndMul, MulAndSilu,
QuickGELU, SiluAndMul) NewGELU, QuickGELU,
SiluAndMul)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol from .allclose_default import get_default_atol, get_default_rtol
...@@ -21,8 +22,9 @@ CUDA_DEVICES = [ ...@@ -21,8 +22,9 @@ CUDA_DEVICES = [
] ]
@pytest.mark.parametrize("activation", @pytest.mark.parametrize(
["silu", "gelu", "gelu_tanh", "fatrelu"]) "activation",
["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D) @pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
...@@ -40,9 +42,12 @@ def test_act_and_mul( ...@@ -40,9 +42,12 @@ def test_act_and_mul(
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype) x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu": if activation == "silu_and_mul":
layer = SiluAndMul() layer = SiluAndMul()
fn = torch.ops._C.silu_and_mul fn = torch.ops._C.silu_and_mul
if activation == "mul_and_silu":
layer = MulAndSilu()
fn = torch.ops._C.mul_and_silu
elif activation == "gelu": elif activation == "gelu":
layer = GeluAndMul(approximate="none") layer = GeluAndMul(approximate="none")
fn = torch.ops._C.gelu_and_mul fn = torch.ops._C.gelu_and_mul
...@@ -55,8 +60,9 @@ def test_act_and_mul( ...@@ -55,8 +60,9 @@ def test_act_and_mul(
fn = torch.ops._C.fatrelu_and_mul fn = torch.ops._C.fatrelu_and_mul
out = layer(x) out = layer(x)
ref_out = layer.forward_native(x) ref_out = layer.forward_native(x)
# The SiLU, GELU and FatReLU implementations are equivalent to the native # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# PyTorch implementations, so we can do exact comparison. # equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
d = x.shape[-1] // 2 d = x.shape[-1] // 2
......
...@@ -31,9 +31,9 @@ NUM_GEN_SEQS = [7] # Arbitrary values for testing ...@@ -31,9 +31,9 @@ NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128 # This should be sync with get_supported_head_sizes() in
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 # vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES = [64, 80, 120, 256] HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True] USE_ALIBI = [False, True]
...@@ -182,7 +182,7 @@ def test_paged_attention( ...@@ -182,7 +182,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = 1.0 k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the paged attention kernel. # Call the paged attention kernel.
output = torch.empty_like(query) output = torch.empty_like(query)
......
from unittest.mock import patch from unittest.mock import Mock, patch
import pytest import pytest
import torch import torch
from tests.kernels.utils import override_backend_env_variable from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms import cpu, cuda, openvino, rocm from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
from vllm.platforms import current_platform from vllm.platforms import current_platform
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"] if not current_platform() else ["ROCM_FLASH"]) "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"] if not current_platform() else ["ROCM_FLASH"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
...@@ -21,71 +31,76 @@ def test_env(name: str, device: str, monkeypatch): ...@@ -21,71 +31,76 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name) override_backend_env_variable(monkeypatch, name)
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform", CpuPlatform()):
cpu.CpuPlatform()): backend = get_attn_backend(16, torch.float16, torch.float16, 16,
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False)
False) assert backend.get_name() == "TORCH_SDPA"
assert backend.name == "TORCH_SDPA"
elif device == "hip": elif device == "hip":
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform", RocmPlatform()):
rocm.RocmPlatform()): backend = get_attn_backend(16, torch.float16, torch.float16, 16,
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False)
False) assert backend.get_name() == "ROCM_FLASH"
assert backend.name == "ROCM_FLASH"
elif device == "openvino": elif device == "openvino":
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform",
openvino.OpenVinoPlatform()): OpenVinoPlatform()), patch.dict('sys.modules',
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, {'openvino': Mock()}):
False) backend = get_attn_backend(16, torch.float16, torch.float16, 16,
assert backend.name == "OPENVINO" False)
assert backend.get_name() == "OPENVINO"
else: else:
with patch("vllm.attention.selector.current_platform", if name in ["XFORMERS", "FLASHINFER"]:
cuda.CudaPlatform()): with patch("vllm.attention.selector.current_platform",
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, CudaPlatform()):
False) backend = get_attn_backend(16, torch.float16, torch.float16,
assert backend.name == name 16, False)
assert backend.get_name() == name
def test_flash_attn(monkeypatch): def test_flash_attn(monkeypatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to # TODO: When testing for v1, pipe in `use_v1` as an argument to
# which_attn_to_use # get_attn_backend
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
# Unsupported CUDA arch # Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)): with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(16, torch.float16, None, 16, False) backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported data type # Unsupported data type
backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False) backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type # Unsupported kv cache data type
backend = which_attn_to_use(16, torch.float16, "fp8", 16, False) backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported block size # Unsupported block size
backend = which_attn_to_use(16, torch.float16, None, 8, False) backend = get_attn_backend(16, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# flash-attn is not installed # flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}): with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(16, torch.float16, None, 16, False) backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported head size # Unsupported head size
backend = which_attn_to_use(17, torch.float16, None, 16, False) backend = get_attn_backend(17, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention # Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True) backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
assert backend.name != STR_FLASH_ATTN_VAL assert backend.get_name() != STR_FLASH_ATTN_VAL
def test_invalid_env(monkeypatch): def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid.""" """Ignore the invalid env variable if it is set."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL) override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError): with patch("vllm.attention.selector.current_platform", CudaPlatform()):
which_attn_to_use(16, torch.float16, None, 16, False) backend = get_attn_backend(32, torch.float16, None, 16, False)
assert backend.get_name() == "FLASH_ATTN"
# when block size == 16, backend will fall back to XFORMERS
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() == "XFORMERS"
...@@ -92,8 +92,10 @@ def native_w8a8_block_fp8_matmul(A, ...@@ -92,8 +92,10 @@ def native_w8a8_block_fp8_matmul(A,
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
] ]
B_tiles = [[ B_tiles = [[
B[j * block_n:min((j + 1) * block_n, N), B[
i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles) j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K),
] for i in range(k_tiles)
] for j in range(n_tiles)] ] for j in range(n_tiles)]
C_tiles = [ C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
...@@ -157,9 +159,9 @@ def setup_cuda(): ...@@ -157,9 +159,9 @@ def setup_cuda():
torch.set_default_device("cuda") torch.set_default_device("cuda")
@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed", @pytest.mark.parametrize(
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, "num_tokens,d,dtype,group_size,seed",
SEEDS)) itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS))
@torch.inference_mode() @torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
...@@ -174,9 +176,9 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): ...@@ -174,9 +176,9 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
assert torch.allclose(scale, ref_scale) assert torch.allclose(scale, ref_scale)
@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", @pytest.mark.parametrize(
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, "M,N,K,block_size,out_dtype,seed",
SEEDS)) itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
...@@ -207,9 +209,10 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): ...@@ -207,9 +209,10 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001 assert rel_diff < 0.001
@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed", @pytest.mark.parametrize(
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, "M,N,K,E,topk,block_size,dtype,seed",
BLOCK_SIZE, DTYPES, SEEDS)) itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES,
SEEDS))
@torch.inference_mode() @torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
torch.manual_seed(seed) torch.manual_seed(seed)
......
...@@ -210,7 +210,7 @@ def test_paged_attention( ...@@ -210,7 +210,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = 1.0 k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
tp_rank = 0 tp_rank = 0
# Call the paged attention kernel. # Call the paged attention kernel.
......
...@@ -161,7 +161,7 @@ def test_reshape_and_cache( ...@@ -161,7 +161,7 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache.clone()
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = 1.0 k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache, opcheck(torch.ops._C_cache_ops.reshape_and_cache,
...@@ -259,8 +259,8 @@ def test_reshape_and_cache_flash( ...@@ -259,8 +259,8 @@ def test_reshape_and_cache_flash(
del key_caches del key_caches
del value_caches del value_caches
k_scale = key.amax().item() / 256 k_scale = (key.amax() / 256.0).to(torch.float32)
v_scale = value.amax().item() / 256 v_scale = (value.amax() / 256.0).to(torch.float32)
# Clone the KV caches. # Clone the KV caches.
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
...@@ -285,12 +285,12 @@ def test_reshape_and_cache_flash( ...@@ -285,12 +285,12 @@ def test_reshape_and_cache_flash(
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache, ops.convert_fp8(result_key_cache,
key_cache, key_cache,
k_scale, k_scale.item(),
kv_dtype=kv_cache_dtype) kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache, ops.convert_fp8(result_value_cache,
value_cache, value_cache,
v_scale, v_scale.item(),
kv_dtype=kv_cache_dtype) kv_dtype=kv_cache_dtype)
# Run the reference implementation. # Run the reference implementation.
......
from typing import List, Optional, Tuple
import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (cascade_attention,
merge_attn_states)
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 192, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.float16, torch.bfloat16]
@pytest.mark.parametrize("num_tokens", [1, 39, 16912])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_merge_kernel(
num_tokens: int,
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
current_platform.seed_everything(0)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
# Prepare inputs.
prefix_output = torch.randn(num_tokens,
num_query_heads,
head_size,
dtype=dtype)
suffix_output = torch.randn(num_tokens,
num_query_heads,
head_size,
dtype=dtype)
prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
# Run the kernel.
output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype)
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse)
# Reference implementation.
max_lse = torch.maximum(prefix_lse, suffix_lse)
p_lse = torch.exp(prefix_lse - max_lse)
s_lse = torch.exp(suffix_lse - max_lse)
p_scale = p_lse / (p_lse + s_lse)
s_scale = s_lse / (p_lse + s_lse)
p_scale = p_scale.transpose(0, 1).unsqueeze(2)
s_scale = s_scale.transpose(0, 1).unsqueeze(2)
ref_output = p_scale * prefix_output + s_scale * suffix_output
ref_output = ref_output.to(dtype)
# Compare the results.
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
CASES = [
# Case 1. A general case.
([(129, 871), (18, 280), (37, 988), (1023, 2304), (1, 257)], 256),
# Case 2. Flash-decoding case.
([(1, 1023), (1, 879), (1, 778), (1, 1777)] * 100, 512),
]
@pytest.mark.parametrize("seq_lens_and_common_prefix", CASES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("soft_cap", [None, 50])
@pytest.mark.parametrize("num_blocks", [2048])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode()
def test_cascade(
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
current_platform.seed_everything(0)
window_size = (-1, -1)
scale = head_size**-0.5
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
seq_lens, common_prefix_len = seq_lens_and_common_prefix
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
total_num_query_tokens = sum(query_lens)
query = torch.randn(total_num_query_tokens,
num_query_heads,
head_size,
dtype=dtype)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
assert common_prefix_len > 0
assert common_prefix_len % block_size == 0
num_common_kv_blocks = common_prefix_len // block_size
# Make sure the first `num_common_kv_blocks` blocks are the same.
block_tables[:, :num_common_kv_blocks] = \
block_tables[0, :num_common_kv_blocks]
# Run the regular attention.
ref_output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
seqused_k=kv_lens_tensor,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)
# Run cascade attention.
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
dtype=torch.int32)
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
suffix_kv_lens = kv_lens_tensor - common_prefix_len
output = torch.empty_like(query)
cascade_attention(
output=output,
query=query,
key_cache=key_cache,
value_cache=value_cache,
cu_query_lens=cu_query_lens,
max_query_len=max_query_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
max_kv_len=max_kv_len,
softmax_scale=scale,
alibi_slopes=None,
sliding_window=window_size,
logits_soft_cap=soft_cap if soft_cap is not None else 0,
block_table=block_tables,
common_prefix_len=common_prefix_len,
fa_version=fa_version,
)
# Compare the results.
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Run `pytest tests/kernels/test_cutlass.py`. Run `pytest tests/kernels/test_cutlass.py`.
""" """
from typing import Optional, Type from typing import Type, Optional
import pytest import pytest
import torch import torch
...@@ -10,6 +10,9 @@ import torch ...@@ -10,6 +10,9 @@ import torch
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv
from .utils import baseline_scaled_mm, to_fp8, to_int8
MNK_FACTORS = [ MNK_FACTORS = [
(1, 256, 128), (1, 256, 128),
...@@ -37,20 +40,15 @@ CUDA_DEVICES = [ ...@@ -37,20 +40,15 @@ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1)
PER_TOKEN_GROUP_SHAPE = (1, -1)
PER_OUT_CH_GROUP_SHAPE = (-1, 1)
capability = current_platform.get_device_capability() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def rand_int8(shape: tuple, device: str = "cuda"): def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128) return to_int8(torch.rand(shape, device=device) * 255 - 128)
...@@ -66,14 +64,22 @@ def baseline_scaled_mm(a: torch.Tensor, ...@@ -66,14 +64,22 @@ def baseline_scaled_mm(a: torch.Tensor,
if bias is not None: if bias is not None:
output = output + bias output = output + bias
return output def group_scale_helper(shape, group_shape):
return [shape[i] if s < 0 else s for i, s in enumerate(group_shape)]
def scale_shape(shape, group_shape):
assert len(shape) == len(group_shape)
group_shape = group_scale_helper(shape, group_shape)
return tuple(
cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
def cutlass_fp8_gemm_helper(m: int, def cutlass_fp8_gemm_helper(m: int,
n: int, n: int,
k: int, k: int,
per_token_act_quant: bool, a_scale_group_shape: tuple,
per_out_channel_weight_quant: bool, b_scale_group_shape: tuple,
use_bias: bool, use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16, out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"): device: str = "cuda"):
...@@ -82,13 +88,17 @@ def cutlass_fp8_gemm_helper(m: int, ...@@ -82,13 +88,17 @@ def cutlass_fp8_gemm_helper(m: int,
a = to_fp8(torch.randn((m, k), device=device)) a = to_fp8(torch.randn((m, k), device=device))
b = to_fp8(torch.randn((n, k), device=device).t()) b = to_fp8(torch.randn((n, k), device=device).t())
m_a_scales = m if per_token_act_quant else 1 a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
n_b_scales = n if per_out_channel_weight_quant else 1 b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
# make scales M-major for blockwise quant, doesn't affect 1D scales
scale_a = scale_a.t().contiguous().t()
# make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b = scale_b.t().contiguous().t()
scale_a = (torch.randn((m_a_scales, 1), device=device,
dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
dtype=torch.float32))
if use_bias: if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
else: else:
...@@ -106,8 +116,8 @@ def cutlass_fp8_gemm_helper(m: int, ...@@ -106,8 +116,8 @@ def cutlass_fp8_gemm_helper(m: int,
def cutlass_int8_gemm_helper(m: int, def cutlass_int8_gemm_helper(m: int,
n: int, n: int,
k: int, k: int,
per_token_act_quant: bool, a_scale_group_shape: tuple,
per_out_channel_weight_quant: bool, b_scale_group_shape: tuple,
use_bias: bool, use_bias: bool,
out_dtype: Type[torch.dtype] = torch.bfloat16, out_dtype: Type[torch.dtype] = torch.bfloat16,
device: str = "cuda"): device: str = "cuda"):
...@@ -116,13 +126,11 @@ def cutlass_int8_gemm_helper(m: int, ...@@ -116,13 +126,11 @@ def cutlass_int8_gemm_helper(m: int,
a = to_int8(torch.randn((m, k), device=device) * 5) a = to_int8(torch.randn((m, k), device=device) * 5)
b = to_int8(torch.randn((n, k), device=device).t() * 5) b = to_int8(torch.randn((n, k), device=device).t() * 5)
m_a_scales = m if per_token_act_quant else 1 a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
n_b_scales = n if per_out_channel_weight_quant else 1 b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn((m_a_scales, 1), device=device, scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
dtype=torch.float32)) scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn((1, n_b_scales), device=device,
dtype=torch.float32))
if use_bias: if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
...@@ -139,85 +147,139 @@ def cutlass_int8_gemm_helper(m: int, ...@@ -139,85 +147,139 @@ def cutlass_int8_gemm_helper(m: int,
# @pytest.mark.parametrize("m,n,k", MNK_FACTORS) # @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
# @pytest.mark.parametrize("per_act_token", [True, False]) # @pytest.mark.parametrize("a_scale_group_shape",
# @pytest.mark.parametrize("per_out_ch", [True, False]) # [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False]) # @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89), # @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.") # reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, # def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
# per_out_ch: bool, use_bias: bool): # b_scale_group_shape, use_bias: bool):
# cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias) # cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
# use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
[((1, 128), (128, 128))])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("a_scale_group_shape",
@pytest.mark.parametrize("per_out_ch", [True, False]) [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
per_out_ch: bool, use_bias: bool): b_scale_group_shape, use_bias: bool):
cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias) cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.parametrize("per_act_token", [True]) @pytest.mark.parametrize("a_scale_group_shape",
@pytest.mark.parametrize("per_out_ch", [True]) [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("out_dtype", [ torch.float16]) #torch.bfloat16, @pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: Type[torch.dtype], out_dtype: Type[torch.dtype],
use_bias: bool): use_bias: bool):
cutlass_int8_gemm_helper(512, cutlass_int8_gemm_helper(512,
512, 512,
512, 512,
per_act_token, a_scale_group_shape,
per_out_ch, b_scale_group_shape,
use_bias, use_bias,
out_dtype=out_dtype) out_dtype=out_dtype)
# @pytest.mark.parametrize("per_act_token", [True, False]) # @pytest.mark.parametrize("a_scale_group_shape",
# @pytest.mark.parametrize("per_out_ch", [True, False]) # [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) # @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [True, False]) # @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89), # @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.") # reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, # def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
# b_scale_group_shape,
# out_dtype: Type[torch.dtype], # out_dtype: Type[torch.dtype],
# use_bias: bool): # use_bias: bool):
# cutlass_fp8_gemm_helper(512, # cutlass_fp8_gemm_helper(512,
# 512, # 512,
# 512, # 512,
# per_act_token, # a_scale_group_shape,
# per_out_ch, # b_scale_group_shape,
# use_bias, # use_bias,
# out_dtype=out_dtype) # out_dtype=out_dtype)
# @pytest.mark.parametrize("per_act_token", [True, False]) # @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
# @pytest.mark.parametrize("per_out_ch", [True, False]) # [((1, 128), (128, 128))])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [False])
# @pytest.mark.skipif(not current_platform.has_device_capability(90),
# reason="FP8 blockwise is not supported on this GPU type.")
# def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
# b_scale_group_shape,
# out_dtype: Type[torch.dtype],
# use_bias: bool):
# cutlass_fp8_gemm_helper(512,
# 512,
# 512,
# a_scale_group_shape,
# b_scale_group_shape,
# use_bias,
# out_dtype=out_dtype)
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False]) # @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.skipif(not current_platform.has_device_capability(89), # @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.") # reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, # def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
# use_bias: bool, device: str): # use_bias: bool, device: str):
# cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias, # cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
# torch.bfloat16, device) # b_scale_group_shape, use_bias, torch.bfloat16,
# device)
# @pytest.mark.parametrize("per_act_token", [True, False]) # @pytest.mark.parametrize("a_scale_group_shape",
# @pytest.mark.parametrize("per_out_ch", [True, False]) # [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False]) # @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
# def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, # def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
# use_bias: bool, device: str): use_bias: bool, device: str):
# cutlass_int8_gemm_helper(512, cutlass_int8_gemm_helper(512,
# 512, 512,
# 512, 512,
# per_act_token, a_scale_group_shape,
# per_out_ch, b_scale_group_shape,
# use_bias, use_bias,
# out_dtype=torch.bfloat16, out_dtype=torch.bfloat16,
# device=device) device=device)
# For the following two tests: # For the following two tests:
...@@ -225,28 +287,32 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, ...@@ -225,28 +287,32 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback # of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the # when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it. # kernel must handle any M thrown at it.
# @pytest.mark.parametrize("per_act_token", [True, False]) # @pytest.mark.parametrize("a_scale_group_shape",
# @pytest.mark.parametrize("per_out_ch", [True, False]) # [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False]) # @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89), # @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.") # reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool, # def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
# use_bias: bool): # use_bias: bool):
# for nk in range(32, 128, 32): # for nk in range(32, 128, 32):
# for m in range(1, 128): # for m in range(1, 128):
# cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, # cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
# use_bias) # b_scale_group_shape, use_bias)
# @pytest.mark.parametrize("per_act_token", [True, False]) # @pytest.mark.parametrize("a_scale_group_shape",
# @pytest.mark.parametrize("per_out_ch", [True, False]) # [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False]) # @pytest.mark.parametrize("use_bias", [True, False])
# def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool, # def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
# use_bias: bool): use_bias: bool):
# for nk in range(32, 128, 32): for nk in range(32, 128, 32):
# for m in range(1, 128): for m in range(1, 128):
# cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
# use_bias) b_scale_group_shape, use_bias)
# @pytest.mark.parametrize("m", [32, 64, 128]) # @pytest.mark.parametrize("m", [32, 64, 128])
...@@ -304,38 +370,39 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, ...@@ -304,38 +370,39 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# @pytest.mark.parametrize("n", [16, 32, 64]) # @pytest.mark.parametrize("n", [16, 32, 64])
# @pytest.mark.parametrize("k", [64, 128, 256]) # @pytest.mark.parametrize("k", [64, 128, 256])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) # @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.skip # @pytest.mark.parametrize("use_bias", [True, False])
# def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, # @pytest.mark.parametrize("azp_per_token", [True, False])
# out_dtype: torch.dtype): # def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
# # Currently, the test is failing because folding azp into # use_bias: bool, azp_per_token: bool):
# # 16-bit bias loses too much precision # m_azp = m if azp_per_token else 1
# scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 # scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
# scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10 # scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
# aq_i8 = rand_int8((m, k)) # aq_i8 = rand_int8((m, k))
# bq_i8 = rand_int8((n, k)).t() # aq_i32 = aq_i8.to(dtype=torch.int32)
# aq_f32 = aq_i8.to(dtype=torch.float32)
# aq_i32 = aq_i8.to(dtype=torch.int32)
# bq_i32 = bq_i8.to(dtype=torch.int32) # bq_i8 = rand_int8((n, k)).t()
# bq_i32 = bq_i8.to(dtype=torch.int32)
# aq_f32 = aq_i8.to(dtype=torch.float32) # bq_f32 = bq_i8.to(dtype=torch.float32)
# bq_f32 = bq_i8.to(dtype=torch.float32) # b_dq = scale_b * bq_f32
# b_dq = scale_b * bq_f32 # azp_a = torch.rand(
# (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
# azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5 # azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
# azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8) # azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
# azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
# a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
# a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32) # torch.testing.assert_close(a_dq,
# torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a) # scale_a * aq_f32 - azp_a,
# rtol=1e-4,
# baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype) # atol=1e-3)
# if use_bias:
# bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
# else:
# bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
# J = torch.ones((1, k), device="cuda", dtype=torch.float32)
# azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
# assert azp_bias.shape == (1, n)
# assert azp_bias[0, :].shape == (n, )
# baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * ( # baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
# (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to( # (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
......
...@@ -2,16 +2,19 @@ ...@@ -2,16 +2,19 @@
Run `pytest tests/kernels/test_semi_structured.py`. Run `pytest tests/kernels/test_semi_structured.py`.
""" """
from typing import Optional, Tuple, Type from typing import Tuple, Type
import pytest import pytest
import torch import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported) sparse_cutlass_supported)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .utils import baseline_scaled_mm, to_fp8, to_int8
CUDA_DEVICES = [ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
...@@ -20,20 +23,6 @@ capability = current_platform.get_device_capability() ...@@ -20,20 +23,6 @@ capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def rand_int8(shape: tuple, device: str = "cuda"):
return to_int8(torch.rand(shape, device=device) * 255 - 128)
def to_bf16(tensor: torch.Tensor) -> torch.Tensor: def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.bfloat16) return tensor.to(dtype=torch.bfloat16)
...@@ -90,22 +79,8 @@ def make_rand_sparse_tensors( ...@@ -90,22 +79,8 @@ def make_rand_sparse_tensors(
return b_compressed, e, a, b return b_compressed, e, a, b
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
output = (scale_a * (scale_b * (torch.mm(
a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
if bias is not None:
output = output + bias
return output
@pytest.mark.skipif(not sparse_cutlass_supported(), @pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.") reason="Sparse CUTLASS is not supported on this GPU type.")
# Test working with a subset of A and B for sparse matmul # Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset(): def test_cutlass_sparse_subset():
...@@ -132,3 +107,108 @@ def test_cutlass_sparse_subset(): ...@@ -132,3 +107,108 @@ def test_cutlass_sparse_subset():
out_dtype=torch.bfloat16) out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
MNK_FACTORS = [
(1, 256, 128),
(1, 16384, 1024),
(1, 24576, 512),
(16, 256, 512),
(16, 16384, 128),
(16, 24576, 4096),
(32, 8192, 4096),
(32, 16384, 4096),
(33, 1024, 1024),
(33, 8192, 128),
(64, 2048, 512),
(64, 16384, 1024),
(100, 8192, 512),
(128, 32768, 4096),
(256, 4096, 4096),
(512, 256, 1024),
(512, 8192, 4096),
(512, 16384, 128),
(512, 24576, 128),
]
# Test working with a subset of A and B for sparse matmul
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=dtype)
baseline = F.linear(a, b.T)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2)
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
...@@ -13,8 +13,7 @@ import pytest ...@@ -13,8 +13,7 @@ import pytest
import torch import torch
from tests.kernels.utils import * from tests.kernels.utils import *
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, from vllm.attention import Attention, AttentionMetadata, AttentionType
AttentionType)
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
...@@ -64,6 +63,7 @@ class TestPoint(NamedTuple): ...@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
max_dec_seq_len: int max_dec_seq_len: int
max_enc_seq_len: int max_enc_seq_len: int
num_blocks: int num_blocks: int
attn_type: AttentionType
class TestResources(NamedTuple): class TestResources(NamedTuple):
...@@ -96,7 +96,6 @@ class TestResources(NamedTuple): ...@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
''' '''
scale: float scale: float
attn_backend: AttentionBackend
attn: Attention attn: Attention
kv_cache: torch.Tensor kv_cache: torch.Tensor
...@@ -129,26 +128,33 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources: ...@@ -129,26 +128,33 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
''' '''
scale = float(1.0 / (test_pt.head_size**0.5)) scale = float(1.0 / (test_pt.head_size**0.5))
attn_backend = make_backend(test_pt.backend_name)
attn = Attention( attn = Attention(
test_pt.num_heads, test_pt.num_heads,
test_pt.head_size, test_pt.head_size,
scale=scale, scale=scale,
prefix=f"{test_pt.attn_type}",
attn_type=test_pt.attn_type,
) )
if test_pt.num_blocks is None or test_pt.num_heads is None: if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache # Caller does not require a KV cache
return TestResources( return TestResources(
scale, attn_backend, attn, scale, attn,
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
# Construct KV cache # Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks, if test_pt.attn_type in (AttentionType.DECODER,
test_pt.num_heads, AttentionType.ENCODER_DECODER):
test_pt.head_size, kv_cache = make_kv_cache(test_pt.num_blocks,
test_pt.block_size, test_pt.num_heads,
device=CUDA_DEVICE, test_pt.head_size,
backend=test_pt.backend_name) test_pt.block_size,
return TestResources(scale, attn_backend, attn, kv_cache) device=CUDA_DEVICE,
backend=test_pt.backend_name)
else:
kv_cache = torch.tensor([])
attn.kv_cache = [kv_cache]
return TestResources(scale, attn, kv_cache)
def _encoder_attn_setup( def _encoder_attn_setup(
...@@ -193,6 +199,7 @@ def _encoder_attn_setup( ...@@ -193,6 +199,7 @@ def _encoder_attn_setup(
_, _,
max_q_seq_len, max_q_seq_len,
_, _,
_,
) = test_pt ) = test_pt
scale = test_rsrcs.scale scale = test_rsrcs.scale
...@@ -301,6 +308,7 @@ def _decoder_attn_setup( ...@@ -301,6 +308,7 @@ def _decoder_attn_setup(
max_q_seq_len, max_q_seq_len,
_, _,
_, _,
_,
) = test_pt ) = test_pt
scale = test_rsrcs.scale scale = test_rsrcs.scale
...@@ -488,6 +496,7 @@ def _enc_dec_cross_attn_setup_reuses_query( ...@@ -488,6 +496,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
max_decoder_seq_len, max_decoder_seq_len,
max_encoder_seq_len, max_encoder_seq_len,
_, _,
_,
) = test_pt ) = test_pt
scale = test_rsrcs.scale scale = test_rsrcs.scale
...@@ -622,7 +631,6 @@ def _run_encoder_attention_test( ...@@ -622,7 +631,6 @@ def _run_encoder_attention_test(
& attn_metadata & attn_metadata
''' '''
assert attn_metadata.num_decode_tokens == 0 assert attn_metadata.num_decode_tokens == 0
attn_type = AttentionType.ENCODER
packed_qkv = encoder_test_params.packed_qkvo.packed_qkv packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
assert packed_qkv is not None assert packed_qkv is not None
with set_forward_context(attn_metadata, vllm_config): with set_forward_context(attn_metadata, vllm_config):
...@@ -635,14 +643,11 @@ def _run_encoder_attention_test( ...@@ -635,14 +643,11 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view( reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query, return attn.forward(
packed_qkv.key, reshaped_query, packed_qkv.key, packed_qkv.value,
packed_qkv.value, torch.tensor([],
torch.tensor([], dtype=torch.float32,
dtype=torch.float32, device=packed_qkv.query.device), attn_metadata)
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
def _run_decoder_self_attention_test( def _run_decoder_self_attention_test(
...@@ -675,7 +680,6 @@ def _run_decoder_self_attention_test( ...@@ -675,7 +680,6 @@ def _run_decoder_self_attention_test(
* Attention.forward() applied to packed_{query,key,value}, kv_cache * Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata & attn_metadata
''' '''
attn_type = AttentionType.DECODER
attn = test_rsrcs.attn attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache kv_cache = test_rsrcs.kv_cache
packed_qkv = decoder_test_params.packed_qkvo.packed_qkv packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
...@@ -690,12 +694,8 @@ def _run_decoder_self_attention_test( ...@@ -690,12 +694,8 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = packed_qkv.query.view( reshaped_query = packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query, return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value,
packed_qkv.key, kv_cache, attn_metadata)
packed_qkv.value,
kv_cache,
attn_metadata,
attn_type=attn_type)
def _run_encoder_decoder_cross_attention_test( def _run_encoder_decoder_cross_attention_test(
...@@ -742,7 +742,6 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -742,7 +742,6 @@ def _run_encoder_decoder_cross_attention_test(
''' '''
assert decoder_test_params.packed_qkvo.packed_qkv is not None assert decoder_test_params.packed_qkvo.packed_qkv is not None
attn_type = AttentionType.ENCODER_DECODER
attn = test_rsrcs.attn attn = test_rsrcs.attn
kv_cache = test_rsrcs.kv_cache kv_cache = test_rsrcs.kv_cache
if cross_test_params is None: if cross_test_params is None:
...@@ -762,12 +761,8 @@ def _run_encoder_decoder_cross_attention_test( ...@@ -762,12 +761,8 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape. # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view(
-1, test_pt.num_heads * test_pt.head_size) -1, test_pt.num_heads * test_pt.head_size)
return attn.forward(reshaped_query, return attn.forward(reshaped_query, key, value, kv_cache,
key, attn_metadata)
value,
kv_cache,
attn_metadata,
attn_type=attn_type)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -839,7 +834,7 @@ def test_encoder_only( ...@@ -839,7 +834,7 @@ def test_encoder_only(
# is not part of this test # is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name, test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len, batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096) max_enc_seq_len, 4096, AttentionType.ENCODER)
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
...@@ -855,7 +850,7 @@ def test_encoder_only( ...@@ -855,7 +850,7 @@ def test_encoder_only(
# Shared prefill metadata structure # Shared prefill metadata structure
prephase_attn_metadata: AttentionMetadata = make_test_metadata( prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend, attn_backend,
True, True,
None, None,
decoder_test_params=None, decoder_test_params=None,
...@@ -961,20 +956,29 @@ def test_e2e_enc_dec_attn( ...@@ -961,20 +956,29 @@ def test_e2e_enc_dec_attn(
# Note: KV cache size of 4096 is arbitrary & chosen intentionally # Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size # to be more than necessary, since exceeding the kv cache size
# is not part of this test # is not part of this test
test_pt = TestPoint(num_heads, head_size, attn_backend.name, enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len, batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096) max_enc_seq_len, 4096, AttentionType.ENCODER)
enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096,
AttentionType.ENCODER_DECODER)
dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
batch_size, block_size, max_dec_seq_len,
max_enc_seq_len, 4096, AttentionType.DECODER)
# Attention scale factor, attention backend instance, attention wrapper # Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init # instance, KV cache init
vllm_config = VllmConfig() vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
test_rsrcs = _make_test_resources(test_pt) enc_test_rsrcs = _make_test_resources(enc_test_pt)
enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
dec_test_rsrcs = _make_test_resources(dec_test_pt)
# Construct encoder attention test params (only used # Construct encoder attention test params (only used
# during prefill) # during prefill)
enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)
# Construct Decoder self-attention prefill-phase & decode-phase # Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention # test params, including query/key/value tensors, decoder self-attention
...@@ -987,7 +991,7 @@ def test_e2e_enc_dec_attn( ...@@ -987,7 +991,7 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params, prephase_dec_test_params,
decphase_dec_test_params, decphase_dec_test_params,
cross_block_base_addr, cross_block_base_addr,
) = _decoder_attn_setup(test_pt, test_rsrcs) ) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)
# Construct encoder/decoder cross-attention prefill-phase # Construct encoder/decoder cross-attention prefill-phase
# & decode-phase test params, including key/value tensors, # & decode-phase test params, including key/value tensors,
...@@ -1000,14 +1004,14 @@ def test_e2e_enc_dec_attn( ...@@ -1000,14 +1004,14 @@ def test_e2e_enc_dec_attn(
dec_qkv, dec_qkv,
enc_test_params, enc_test_params,
prephase_dec_test_params, prephase_dec_test_params,
test_pt, enc_dec_test_pt,
test_rsrcs, enc_dec_test_rsrcs,
block_base_addr=cross_block_base_addr) block_base_addr=cross_block_base_addr)
# Shared prefill metadata structure # Shared prefill metadata structure
assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
prephase_attn_metadata: AttentionMetadata = make_test_metadata( prephase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend, attn_backend,
True, True,
prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
decoder_test_params=prephase_dec_test_params, decoder_test_params=prephase_dec_test_params,
...@@ -1017,10 +1021,10 @@ def test_e2e_enc_dec_attn( ...@@ -1017,10 +1021,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder attention # PREFILL: encoder attention
enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
enc_test_params, enc_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt, test_pt=enc_test_pt,
vllm_config=vllm_config) vllm_config=vllm_config)
# - Is encoder attention result correct? # - Is encoder attention result correct?
...@@ -1030,10 +1034,10 @@ def test_e2e_enc_dec_attn( ...@@ -1030,10 +1034,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: decoder self-attention test # PREFILL: decoder self-attention test
prephase_dec_pckd_act_out = _run_decoder_self_attention_test( prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, dec_test_rsrcs,
prephase_dec_test_params, prephase_dec_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt, test_pt=dec_test_pt,
vllm_config=vllm_config) vllm_config=vllm_config)
# - Is prefill decoder self-attention correct? # - Is prefill decoder self-attention correct?
...@@ -1044,11 +1048,11 @@ def test_e2e_enc_dec_attn( ...@@ -1044,11 +1048,11 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder/decoder cross-attention test # PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, enc_dec_test_rsrcs,
prephase_dec_test_params, prephase_dec_test_params,
prephase_cross_test_params, prephase_cross_test_params,
prephase_attn_metadata, prephase_attn_metadata,
test_pt=test_pt, test_pt=enc_dec_test_pt,
vllm_config=vllm_config) vllm_config=vllm_config)
# - Is prefill encoder/decoder cross-attention correct? # - Is prefill encoder/decoder cross-attention correct?
...@@ -1059,7 +1063,7 @@ def test_e2e_enc_dec_attn( ...@@ -1059,7 +1063,7 @@ def test_e2e_enc_dec_attn(
# DECODE: build decode-phase attention metadata # DECODE: build decode-phase attention metadata
decphase_attn_metadata: AttentionMetadata = make_test_metadata( decphase_attn_metadata: AttentionMetadata = make_test_metadata(
test_rsrcs.attn_backend, attn_backend,
False, False,
dec_qkv.q_seq_lens, dec_qkv.q_seq_lens,
decoder_test_params=decphase_dec_test_params, decoder_test_params=decphase_dec_test_params,
...@@ -1070,10 +1074,10 @@ def test_e2e_enc_dec_attn( ...@@ -1070,10 +1074,10 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test # DECODE: decoder self-attention test
decphase_dec_pckd_act_out = _run_decoder_self_attention_test( decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
test_rsrcs, dec_test_rsrcs,
decphase_dec_test_params, decphase_dec_test_params,
decphase_attn_metadata, decphase_attn_metadata,
test_pt=test_pt, test_pt=dec_test_pt,
vllm_config=vllm_config) vllm_config=vllm_config)
# - Is decode-phase decoder self-attention correct? # - Is decode-phase decoder self-attention correct?
...@@ -1084,11 +1088,11 @@ def test_e2e_enc_dec_attn( ...@@ -1084,11 +1088,11 @@ def test_e2e_enc_dec_attn(
# DECODE: encoder/decoder cross-attention test # DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
test_rsrcs, enc_dec_test_rsrcs,
decphase_dec_test_params, decphase_dec_test_params,
None, None,
decphase_attn_metadata, decphase_attn_metadata,
test_pt=test_pt, test_pt=enc_dec_test_pt,
vllm_config=vllm_config) vllm_config=vllm_config)
# - Is decode-phase encoder/decoder cross-attention correct? # - Is decode-phase encoder/decoder cross-attention correct?
......
...@@ -5,11 +5,14 @@ import torch ...@@ -5,11 +5,14 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform(): if current_platform():
import flash_attn import flash_attn
else: else:
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_with_kvcache) flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)] NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
...@@ -84,6 +87,7 @@ if not current_platform(): ...@@ -84,6 +87,7 @@ if not current_platform():
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode() @torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
use_out: bool, use_out: bool,
...@@ -95,8 +99,13 @@ if not current_platform(): ...@@ -95,8 +99,13 @@ if not current_platform():
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
sliding_window: Optional[int], sliding_window: Optional[int],
fa_version: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
num_query_heads = num_heads[0] num_query_heads = num_heads[0]
...@@ -135,6 +144,7 @@ if not current_platform(): ...@@ -135,6 +144,7 @@ if not current_platform():
cache_seqlens=kv_lens_tensor, cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size, window_size=window_size,
fa_version=fa_version,
) )
output = output if not use_out else out output = output if not use_out else out
output = output.squeeze(1) output = output.squeeze(1)
...@@ -150,9 +160,8 @@ if not current_platform(): ...@@ -150,9 +160,8 @@ if not current_platform():
sliding_window=sliding_window) sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens", @pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18), [[(1, 1328), (5, 18),
...@@ -164,6 +173,7 @@ if not current_platform(): ...@@ -164,6 +173,7 @@ if not current_platform():
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@torch.inference_mode() @torch.inference_mode()
def test_varlen_with_paged_kv( def test_varlen_with_paged_kv(
use_out: bool, use_out: bool,
...@@ -175,8 +185,12 @@ def test_varlen_with_paged_kv( ...@@ -175,8 +185,12 @@ def test_varlen_with_paged_kv(
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
fa_version: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
...@@ -206,6 +220,7 @@ def test_varlen_with_paged_kv( ...@@ -206,6 +220,7 @@ def test_varlen_with_paged_kv(
cu_kv_lens = torch.tensor([0] + kv_lens, cu_kv_lens = torch.tensor([0] + kv_lens,
dtype=torch.int32).cumsum(dim=0, dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32) dtype=torch.int32)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0, block_tables = torch.randint(0,
...@@ -230,6 +245,7 @@ def test_varlen_with_paged_kv( ...@@ -230,6 +245,7 @@ def test_varlen_with_paged_kv(
window_size=window_size, window_size=window_size,
block_table=block_tables, block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
# fa_version=fa_version,
) )
else: else:
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
...@@ -238,7 +254,7 @@ def test_varlen_with_paged_kv( ...@@ -238,7 +254,7 @@ def test_varlen_with_paged_kv(
v=value_cache, v=value_cache,
out=out, out=out,
cu_seqlens_q=cu_query_lens, cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens, seqused_k=kv_lens,
max_seqlen_q=max_query_len, max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len, max_seqlen_k=max_kv_len,
softmax_scale=scale, softmax_scale=scale,
...@@ -246,7 +262,9 @@ def test_varlen_with_paged_kv( ...@@ -246,7 +262,9 @@ def test_varlen_with_paged_kv(
window_size=window_size, window_size=window_size,
block_table=block_tables, block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
) )
output = output if not use_out else out output = output if not use_out else out
ref_output = ref_paged_attn( ref_output = ref_paged_attn(
......
"""
Test:
* Tests for MultiHeadAttention layer
"""
from unittest.mock import patch
import pytest
import torch
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend, _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
"""
Test the attention selector between different platform and device.
"""
torch.set_default_dtype(torch.float16)
if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
else:
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
def ref_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
"""
Native implementation of scaled dot product attention without mask:
- query, key, value: [batch_size, seq_len, num_heads, head_size]
- attn_mask: [batch_size, seq_len, seq_len]
"""
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.matmul(attn_weights, value).transpose(1, 2)
return out
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
CUDA_DEVICES = ["cuda"]
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
batch_size: int,
seq_len: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: str,
):
current_platform.seed_everything(0)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
q = torch.randn(batch_size, seq_len, num_heads * head_size)
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
scale = 1.0 / head_size**0.5
attn = MultiHeadAttention(num_heads,
head_size,
scale=scale,
num_kv_heads=num_kv_heads)
output = attn(q, k, v)
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
q = q.reshape(batch_size, seq_len, num_heads, head_size)
k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
if num_queries_per_kv > 1:
k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
ref_output = ref_attention(
q,
k,
v,
scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size)
torch.testing.assert_close(output, ref_output)
...@@ -14,8 +14,12 @@ from vllm import _custom_ops as ops ...@@ -14,8 +14,12 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size) fused_topk, moe_align_block_size)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize) marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
...@@ -46,8 +50,102 @@ def test_fused_moe( ...@@ -46,8 +50,102 @@ def test_fused_moe(
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk) torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
torch.testing.assert_close(iterative_output,
torch_output,
atol=2e-2,
rtol=0)
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype, group_size: int, has_zp: bool,
weight_bits: int):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
if weight_bits == 4:
pack_factor = 2
quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
elif weight_bits == 8:
pack_factor = 1
quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
w1_ref = w1.clone()
w2_ref = w2.clone()
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
device="cuda",
dtype=torch.uint8)
w2_qweight = torch.empty((e, k, n // pack_factor),
device="cuda",
dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size),
device="cuda",
dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size),
device="cuda",
dtype=dtype)
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
device="cuda",
dtype=torch.uint8)
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
device="cuda",
dtype=torch.uint8)
for i in range(e * 2):
expert_id = i % e
if i // e == 0:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
else:
w, w_ref, w_qweight, w_scales, w_qzeros = \
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
weight, qweight, scales, qzeros = quantize_weights(
w[expert_id].T, quant_type, group_size, has_zp, False)
weight = weight.T
qweight = qweight.T.contiguous().to(torch.uint8)
scales = scales.T
if has_zp:
qzeros = qzeros.T.contiguous().to(torch.uint8)
if weight_bits == 4:
qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
if has_zp:
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
w_ref[expert_id] = weight
w_qweight[expert_id] = qweight
w_scales[expert_id] = scales
if has_zp:
w_qzeros[expert_id] = qzeros
triton_output = fused_moe(a,
w1_qweight,
w2_qweight,
score,
topk,
renormalize=False,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=w1_scales,
w2_scale=w2_scales,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.parametrize("dtype", @pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16]) [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode() @torch.inference_mode()
......
...@@ -140,6 +140,7 @@ def test_contexted_kv_attention( ...@@ -140,6 +140,7 @@ def test_contexted_kv_attention(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads, v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous() head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring # Warm up the Triton kernel by calling it once before actually measuring
# generation time # generation time
...@@ -155,6 +156,8 @@ def test_contexted_kv_attention( ...@@ -155,6 +156,8 @@ def test_contexted_kv_attention(
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window) sliding_window=sliding_window)
torch.cuda.synchronize() torch.cuda.synchronize()
start_time = time.time() start_time = time.time()
...@@ -170,6 +173,8 @@ def test_contexted_kv_attention( ...@@ -170,6 +173,8 @@ def test_contexted_kv_attention(
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window) sliding_window=sliding_window)
torch.cuda.synchronize() torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
...@@ -369,6 +374,7 @@ def test_contexted_kv_attention_alibi( ...@@ -369,6 +374,7 @@ def test_contexted_kv_attention_alibi(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads, v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous() head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring # Warm up the Triton kernel by calling it once before actually measuring
# generation time # generation time
...@@ -384,6 +390,8 @@ def test_contexted_kv_attention_alibi( ...@@ -384,6 +390,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes)
torch.cuda.synchronize() torch.cuda.synchronize()
start_time = time.time() start_time = time.time()
...@@ -399,6 +407,8 @@ def test_contexted_kv_attention_alibi( ...@@ -399,6 +407,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes)
torch.cuda.synchronize() torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
......
import pytest
import torch
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
def cdiv(a, b):
return (a + b - 1) // b
@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L # This represents the number of tokens already in the sequence
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1),
device="cuda")
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
# Call the original implementation.
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
# Page size can be larger than 1.
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1)
...@@ -39,6 +39,23 @@ def get_8bit_types(): ...@@ -39,6 +39,23 @@ def get_8bit_types():
return types return types
# This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
max_tokens, num_logprobs):
dtype = "bfloat16"
with vllm_runner(model_path, dtype=dtype) as vllm_model:
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
num_logprobs)
@pytest.mark.parametrize("M", [1, 33, 64, 512]) @pytest.mark.parametrize("M", [1, 33, 64, 512])
@pytest.mark.parametrize("N", [256, 971, 20486]) @pytest.mark.parametrize("N", [256, 971, 20486])
@pytest.mark.parametrize("K", [128, 496, 1024]) @pytest.mark.parametrize("K", [128, 496, 1024])
......
...@@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores=( use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4) (num_query_heads//num_kv_heads) > 4)
) )
wrapper.begin_forward(kv_indptr, wrapper.plan(kv_indptr,
kv_indices, kv_indices,
kv_last_page_lens, kv_last_page_lens,
num_query_heads, num_query_heads,
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
"NONE", "NONE",
data_type=dtype) q_data_type=dtype,
kv_data_type=dtype,
output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap) logits_soft_cap=soft_cap)
output = wrapper.run(query, key_value_cache)
ref_output = ref_paged_attn(query=query, ref_output = ref_paged_attn(query=query,
key_cache=key_cache, key_cache=key_cache,
...@@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], ...@@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD") workspace_buffer, "NHD")
wrapper.begin_forward( wrapper.plan(
qo_indptr, qo_indptr,
kv_indptr, kv_indptr,
kv_indices, kv_indices,
...@@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], ...@@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap,
) )
output = wrapper.forward( output = wrapper.run(
query, query,
key_value_cache, key_value_cache,
logits_soft_cap=soft_cap,
) )
ref_output = ref_paged_attn(query=query, ref_output = ref_paged_attn(query=query,
...@@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], ...@@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( ...@@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD") workspace_buffer, "NHD")
wrapper.begin_forward( wrapper.plan(
qo_indptr, qo_indptr,
kv_indptr, kv_indptr,
kv_indices, kv_indices,
...@@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv( ...@@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
q_data_type=dtype,
kv_data_type=kv_cache_dtype,
logits_soft_cap=soft_cap,
) )
output = wrapper.forward(query, output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
ref_output = ref_paged_attn(query=query, ref_output = ref_paged_attn(query=query,
key_cache=key_cache.squeeze(1), key_cache=key_cache.squeeze(1),
...@@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( ...@@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
del query del query
del block_tables del block_tables
# verify prefill fp8 # verify prefill fp8
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv( ...@@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv(
wrapper = flashinfer.\ wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=use_tensor_cores) use_tensor_cores=use_tensor_cores)
wrapper.begin_forward(kv_indptr, wrapper.plan(kv_indptr,
kv_indices, kv_indices,
kv_last_page_lens, kv_last_page_lens,
num_query_heads, num_query_heads,
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
"NONE", "NONE",
data_type=dtype, q_data_type=dtype,
q_data_type=dtype) kv_data_type=kv_cache_dtype,
output = wrapper.forward(query, logits_soft_cap=soft_cap)
kv_cache_fp8, output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
......
...@@ -5,7 +5,7 @@ import random ...@@ -5,7 +5,7 @@ import random
import unittest import unittest
from numbers import Number from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Union) Type, Union)
import pytest import pytest
import torch import torch
...@@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType ...@@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
...@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping( ...@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(
def make_test_metadata( def make_test_metadata(
attn_backend: AttentionBackend, attn_backend: _Backend,
is_prompt: bool, is_prompt: bool,
seq_lens: Optional[List[int]], seq_lens: Optional[List[int]],
decoder_test_params: Optional[PhaseTestParameters], decoder_test_params: Optional[PhaseTestParameters],
...@@ -815,7 +816,7 @@ def make_test_metadata( ...@@ -815,7 +816,7 @@ def make_test_metadata(
Arguments: Arguments:
* attn_backend: Backend for sourcing attention kernels * attn_backend_name: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode * is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence * seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params; * decoder_test_params: decoder self-attention test params;
...@@ -882,6 +883,8 @@ def make_test_metadata( ...@@ -882,6 +883,8 @@ def make_test_metadata(
# (kv_mmap) # (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap cross_kv_mmap = cross_test_params.kv_mmap
attn_backend_obj = make_backend(attn_backend.name)
if is_prompt: if is_prompt:
# Prefill-phase scenario # Prefill-phase scenario
...@@ -902,11 +905,11 @@ def make_test_metadata( ...@@ -902,11 +905,11 @@ def make_test_metadata(
context_lens, context_lens,
encoder_seq_lens, encoder_seq_lens,
device=device) device=device)
return attn_backend_obj.make_metadata(
return attn_backend.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
...@@ -952,10 +955,11 @@ def make_test_metadata( ...@@ -952,10 +955,11 @@ def make_test_metadata(
encoder_seq_lens, encoder_seq_lens,
device=device) device=device)
return attn_backend.make_metadata( return attn_backend_obj.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping, slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
...@@ -1096,3 +1100,56 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, ...@@ -1096,3 +1100,56 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
kwargs, kwargs,
test_utils=test_utils, test_utils=test_utils,
raise_exception=raise_exception) if cond else {} raise_exception=raise_exception) if cond else {}
# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# We treat N-dimensional group scaling as extended numpy-style broadcasting
# in numpy simply stretches dimensions with an extent of 1 to match the
# the target shape by repeating the data along that dimension (broadcasting)
# , we extend these semantics to say if the extent of a dimension in the
# source shape is not 1 and does not match the target shape we repeat each
# element along that dimension src_shape[dim] // target_shape[dim] times
# example if we have:
# a = [[1, 2], and target_shape = (2, 4)
# [3, 4]]
# then we would expand a to:
# a = [[1, 1, 2, 2],
# [3, 3, 4, 4]]
# NOTE this function this function does not explicitly broadcast dimensions
# with an extent of 1, since this can be done implicitly by pytorch
def group_broadcast(t, shape):
for i, s in enumerate(shape):
if t.shape[i] != s and t.shape[i] != 1:
assert s % t.shape[i] == 0
t = t.unsqueeze(i + 1)\
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
.flatten(i, i + 1)
return t
scale_a = group_broadcast(scale_a, a.shape)
scale_b = group_broadcast(scale_b, b.shape)
output = torch.mm((scale_a * a.to(dtype=torch.float32)),
(scale_b * b.to(dtype=torch.float32))).to(out_dtype)
if bias is not None:
output = output + bias
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment