Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
afd0da21
Commit
afd0da21
authored
Feb 03, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.1' into v0.7.1-dev
parents
1a11f127
4f4d427a
Changes
587
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1197 additions
and
418 deletions
+1197
-418
tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json
tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json
+0
-90
tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json
tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json
+0
-42
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+13
-7
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+4
-4
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+55
-40
tests/kernels/test_block_fp8.py
tests/kernels/test_block_fp8.py
+14
-11
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+1
-1
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+5
-5
tests/kernels/test_cascade_flash_attn.py
tests/kernels/test_cascade_flash_attn.py
+187
-0
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+177
-110
tests/kernels/test_cutlass_2of4_sparse.py
tests/kernels/test_cutlass_2of4_sparse.py
+214
-0
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+61
-57
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+24
-6
tests/kernels/test_mha_attn.py
tests/kernels/test_mha_attn.py
+126
-0
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+100
-2
tests/kernels/test_prefix_prefill.py
tests/kernels/test_prefix_prefill.py
+10
-0
tests/kernels/test_triton_decode_attention.py
tests/kernels/test_triton_decode_attention.py
+89
-0
tests/kernels/test_triton_scaled_mm.py
tests/kernels/test_triton_scaled_mm.py
+17
-0
tests/kernels/untest_flashinfer.py
tests/kernels/untest_flashinfer.py
+37
-37
tests/kernels/utils.py
tests/kernels/utils.py
+63
-6
No files found.
Too many changes to show.
To preserve performance only
587 of 587+
files are displayed.
Plain diff
Email patch
tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json
deleted
100644 → 0
View file @
1a11f127
{
"model_type"
:
"llama"
,
"kv_cache"
:
{
"dtype"
:
"float8_e4m3fn"
,
"scaling_factor"
:
{
"0"
:
{
"0"
:
0.0230364128947258
,
"1"
:
0.01979283057153225
,
"2"
:
0.0241350457072258
,
"3"
:
0.0308314748108387
,
"4"
:
0.0430733822286129
,
"5"
:
0.0370396226644516
,
"6"
:
0.0306222103536129
,
"7"
:
0.0357491634786129
,
"8"
:
0.0358189195394516
,
"9"
:
0.0443289652466774
,
"10"
:
0.0433175228536129
,
"11"
:
0.0416782945394516
,
"12"
:
0.0366908498108387
,
"13"
:
0.0432477705180645
,
"14"
:
0.0410505048930645
,
"15"
:
0.0457589291036129
,
"16"
:
0.0418526791036129
,
"17"
:
0.0432477705180645
,
"18"
:
0.0469447560608387
,
"19"
:
0.0514787957072258
,
"20"
:
0.0541294664144516
,
"21"
:
0.0587681382894516
,
"22"
:
0.0625
,
"23"
:
0.0585588738322258
,
"24"
:
0.0600237175822258
,
"25"
:
0.0588030144572258
,
"26"
:
0.0531180277466774
,
"27"
:
0.06396484375
,
"28"
:
0.0603027381002903
,
"29"
:
0.0582101047039032
,
"30"
:
0.0625348836183548
,
"31"
:
0.0585588738322258
,
"32"
:
0.0582798570394516
,
"33"
:
0.0575125589966774
,
"34"
:
0.0590820349752903
,
"35"
:
0.0614188089966774
,
"36"
:
0.0631975457072258
,
"37"
:
0.0615931935608387
,
"38"
:
0.0601283498108387
,
"39"
:
0.0571986623108387
,
"40"
:
0.0670340433716774
,
"41"
:
0.0523507259786129
,
"42"
:
0.0547223798930645
,
"43"
:
0.0631975457072258
,
"44"
:
0.0663713738322258
,
"45"
:
0.0603376142680645
,
"46"
:
0.0652204304933548
,
"47"
:
0.0734514519572258
,
"48"
:
0.0693708211183548
,
"49"
:
0.0725446492433548
,
"50"
:
0.0627790242433548
,
"51"
:
0.0691266804933548
,
"52"
:
0.0688825398683548
,
"53"
:
0.068429134786129
,
"54"
:
0.0605119988322258
,
"55"
:
0.0799386203289032
,
"56"
:
0.0853097140789032
,
"57"
:
0.0661969929933548
,
"58"
:
0.0689871683716774
,
"59"
:
0.0724051371216774
,
"60"
:
0.0541643425822258
,
"61"
:
0.0626743882894516
,
"62"
:
0.0628487765789032
,
"63"
:
0.0607212632894516
,
"64"
:
0.0589076466858387
,
"65"
:
0.0451660193502903
,
"66"
:
0.0453055277466774
,
"67"
:
0.0414341539144516
,
"68"
:
0.0385044664144516
,
"69"
:
0.0414341539144516
,
"70"
:
0.0466308631002903
,
"71"
:
0.0399693101644516
,
"72"
:
0.0437011756002903
,
"73"
:
0.0434221550822258
,
"74"
:
0.0428989976644516
,
"75"
:
0.0401785746216774
,
"76"
:
0.0431082621216774
,
"77"
:
0.0484444759786129
,
"78"
:
0.0417829267680645
,
"79"
:
0.0418178029358387
}
}
}
}
\ No newline at end of file
tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json
deleted
100644 → 0
View file @
1a11f127
{
"model_type"
:
"llama"
,
"kv_cache"
:
{
"dtype"
:
"float8_e4m3fn"
,
"scaling_factor"
:
{
"0"
:
{
"0"
:
0.0152239128947258
,
"1"
:
0.0188860222697258
,
"2"
:
0.0354178324341774
,
"3"
:
0.0376674123108387
,
"4"
:
0.0418526791036129
,
"5"
:
0.0433175228536129
,
"6"
:
0.0397600457072258
,
"7"
:
0.0424455925822258
,
"8"
:
0.0415387861430645
,
"9"
:
0.0408412404358387
,
"10"
:
0.0395856611430645
,
"11"
:
0.0377371683716774
,
"12"
:
0.0400739423930645
,
"13"
:
0.040771484375
,
"14"
:
0.0393415205180645
,
"15"
:
0.0369001142680645
,
"16"
:
0.03857421875
,
"17"
:
0.0387486070394516
,
"18"
:
0.0403180830180645
,
"19"
:
0.0396205373108387
,
"20"
:
0.0375627800822258
,
"21"
:
0.0407366082072258
,
"22"
:
0.0432477705180645
,
"23"
:
0.0377022884786129
,
"24"
:
0.0399693101644516
,
"25"
:
0.0374581478536129
,
"26"
:
0.0413295216858387
,
"27"
:
0.0442243330180645
,
"28"
:
0.0424804724752903
,
"29"
:
0.0456891767680645
,
"30"
:
0.0409109964966774
,
"31"
:
0.0482352152466774
}
}
}
}
tests/kernels/test_activation.py
View file @
afd0da21
...
...
@@ -6,8 +6,9 @@ import torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.activation
import
(
FastGELU
,
FatreluAndMul
,
GeluAndMul
,
NewGELU
,
QuickGELU
,
SiluAndMul
)
GeluAndMul
,
MulAndSilu
,
NewGELU
,
QuickGELU
,
SiluAndMul
)
from
vllm.platforms
import
current_platform
from
.allclose_default
import
get_default_atol
,
get_default_rtol
...
...
@@ -21,8 +22,9 @@ CUDA_DEVICES = [
]
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
"silu"
,
"gelu"
,
"gelu_tanh"
,
"fatrelu"
])
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
"silu_and_mul"
,
"mul_and_silu"
,
"gelu"
,
"gelu_tanh"
,
"fatrelu"
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"d"
,
D
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
...
@@ -40,9 +42,12 @@ def test_act_and_mul(
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
x
=
torch
.
randn
(
num_tokens
,
2
*
d
,
dtype
=
dtype
)
if
activation
==
"silu"
:
if
activation
==
"silu
_and_mul
"
:
layer
=
SiluAndMul
()
fn
=
torch
.
ops
.
_C
.
silu_and_mul
if
activation
==
"mul_and_silu"
:
layer
=
MulAndSilu
()
fn
=
torch
.
ops
.
_C
.
mul_and_silu
elif
activation
==
"gelu"
:
layer
=
GeluAndMul
(
approximate
=
"none"
)
fn
=
torch
.
ops
.
_C
.
gelu_and_mul
...
...
@@ -55,8 +60,9 @@ def test_act_and_mul(
fn
=
torch
.
ops
.
_C
.
fatrelu_and_mul
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
# The SiLU, GELU and FatReLU implementations are equivalent to the native
# PyTorch implementations, so we can do exact comparison.
# The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
# equivalent to the native PyTorch implementations, so we can do exact
# comparison.
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
d
=
x
.
shape
[
-
1
]
//
2
...
...
tests/kernels/test_attention.py
View file @
afd0da21
...
...
@@ -31,9 +31,9 @@ NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS
=
[
3
]
# Arbitrary values for testing
NUM_HEADS
=
[(
40
,
40
),
(
64
,
8
)]
# Arbitrary values for testing
#
FlashAttention forward only
support
s
head
dimension at most 128
#
https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES
=
[
64
,
80
,
120
,
256
]
#
This should be sync with get_
support
ed_
head
_sizes() in
#
vllm.attention.ops.paged_attn.PagedAttention
HEAD_SIZES
=
[
32
,
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
USE_ALIBI
=
[
False
,
True
]
...
...
@@ -182,7 +182,7 @@ def test_paged_attention(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
k_scale
=
v_scale
=
1.0
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
# Call the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
...
...
tests/kernels/test_attention_selector.py
View file @
afd0da21
from
unittest.mock
import
patch
from
unittest.mock
import
Mock
,
patch
import
pytest
import
torch
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
which_attn_to_use
from
vllm.platforms
import
cpu
,
cuda
,
openvino
,
rocm
from
vllm.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.openvino
import
OpenVinoPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.platforms
import
current_platform
@
pytest
.
fixture
(
autouse
=
True
)
def
clear_cache
():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend
.
cache_clear
()
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
]
if
not
current_platform
()
else
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
...
...
@@ -21,71 +31,76 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable
(
monkeypatch
,
name
)
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
cpu
.
CpuPlatform
()):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"TORCH_SDPA"
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"TORCH_SDPA"
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
rocm
.
RocmPlatform
()):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"ROCM_FLASH"
with
patch
(
"vllm.attention.selector.current_platform"
,
RocmPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
openvino
.
OpenVinoPlatform
()):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
"OPENVINO"
OpenVinoPlatform
()),
patch
.
dict
(
'sys.modules'
,
{
'openvino'
:
Mock
()}):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"OPENVINO"
else
:
with
patch
(
"vllm.attention.selector.current_platform"
,
cuda
.
CudaPlatform
()):
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
name
==
name
if
name
in
[
"XFORMERS"
,
"FLASHINFER"
]:
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
name
def
test_flash_attn
(
monkeypatch
):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
#
which
_attn_
to_use
#
get
_attn_
backend
override_backend_env_variable
(
monkeypatch
,
STR_FLASH_ATTN_VAL
)
# Unsupported CUDA arch
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
(
7
,
5
)):
backend
=
which
_attn_
to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
backend
=
get
_attn_
backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_
name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported data type
backend
=
which
_attn_
to_use
(
16
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
backend
=
get
_attn_
backend
(
16
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
assert
backend
.
get_
name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend
=
which
_attn_
to_use
(
16
,
torch
.
float16
,
"fp8"
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
backend
=
get
_attn_
backend
(
16
,
torch
.
float16
,
"fp8"
,
16
,
False
)
assert
backend
.
get_
name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported block size
backend
=
which
_attn_
to_use
(
16
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
backend
=
get
_attn_
backend
(
16
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
get_
name
()
!=
STR_FLASH_ATTN_VAL
# flash-attn is not installed
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
backend
=
which
_attn_
to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
backend
=
get
_attn_
backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_
name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported head size
backend
=
which
_attn_
to_use
(
17
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
backend
=
get
_attn_
backend
(
17
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_
name
()
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend
=
which
_attn_
to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
backend
=
get
_attn_
backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
assert
backend
.
get_
name
()
!=
STR_FLASH_ATTN_VAL
def
test_invalid_env
(
monkeypatch
):
"""
Throw an exception if the backend nam
e i
s
i
nvalid
."""
"""
Ignore the invalid env variabl
e i
f
i
t is set
."""
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
with
pytest
.
raises
(
ValueError
):
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
==
"FLASH_ATTN"
# when block size == 16, backend will fall back to XFORMERS
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
==
"XFORMERS"
tests/kernels/test_block_fp8.py
View file @
afd0da21
...
...
@@ -92,8 +92,10 @@ def native_w8a8_block_fp8_matmul(A,
A
[:,
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
K
)]
for
i
in
range
(
k_tiles
)
]
B_tiles
=
[[
B
[
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
N
),
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
K
),
]
for
i
in
range
(
k_tiles
)
B
[
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
N
),
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
K
),
]
for
i
in
range
(
k_tiles
)
]
for
j
in
range
(
n_tiles
)]
C_tiles
=
[
C
[:,
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
N
)]
for
j
in
range
(
n_tiles
)
...
...
@@ -157,9 +159,9 @@ def setup_cuda():
torch
.
set_default_device
(
"cuda"
)
@
pytest
.
mark
.
parametrize
(
"num_tokens,d,dtype,group_size,seed"
,
itertools
.
product
(
NUM_TOKENS
,
D
,
DTYPES
,
GROUP_SIZE
,
SEEDS
))
@
pytest
.
mark
.
parametrize
(
"num_tokens,d,dtype,group_size,seed"
,
itertools
.
product
(
NUM_TOKENS
,
D
,
DTYPES
,
GROUP_SIZE
,
SEEDS
))
@
torch
.
inference_mode
()
def
test_per_token_group_quant_fp8
(
num_tokens
,
d
,
dtype
,
group_size
,
seed
):
torch
.
manual_seed
(
seed
)
...
...
@@ -174,9 +176,9 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
assert
torch
.
allclose
(
scale
,
ref_scale
)
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
@
pytest
.
mark
.
parametrize
(
"
M,N,K,
block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
torch
.
manual_seed
(
seed
)
...
...
@@ -207,9 +209,10 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert
rel_diff
<
0.001
@
pytest
.
mark
.
parametrize
(
"M,N,K,E,topk,block_size,dtype,seed"
,
itertools
.
product
(
M_moe
,
N_moe
,
K_moe
,
E
,
TOP_KS
,
BLOCK_SIZE
,
DTYPES
,
SEEDS
))
@
pytest
.
mark
.
parametrize
(
"M,N,K,E,topk,block_size,dtype,seed"
,
itertools
.
product
(
M_moe
,
N_moe
,
K_moe
,
E
,
TOP_KS
,
BLOCK_SIZE
,
DTYPES
,
SEEDS
))
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
block_size
,
dtype
,
seed
):
torch
.
manual_seed
(
seed
)
...
...
tests/kernels/test_blocksparse_attention.py
View file @
afd0da21
...
...
@@ -210,7 +210,7 @@ def test_paged_attention(
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
k_scale
=
v_scale
=
1.0
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
tp_rank
=
0
# Call the paged attention kernel.
...
...
tests/kernels/test_cache.py
View file @
afd0da21
...
...
@@ -161,7 +161,7 @@ def test_reshape_and_cache(
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
k_scale
=
v_scale
=
1.0
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
...
...
@@ -259,8 +259,8 @@ def test_reshape_and_cache_flash(
del
key_caches
del
value_caches
k_scale
=
key
.
amax
()
.
item
()
/
256
v_scale
=
value
.
amax
()
.
item
()
/
256
k_scale
=
(
key
.
amax
()
/
256
.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
256
.0
).
to
(
torch
.
float32
)
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
...
...
@@ -285,12 +285,12 @@ def test_reshape_and_cache_flash(
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
,
k_scale
,
k_scale
.
item
()
,
kv_dtype
=
kv_cache_dtype
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
,
v_scale
,
v_scale
.
item
()
,
kv_dtype
=
kv_cache_dtype
)
# Run the reference implementation.
...
...
tests/kernels/test_cascade_flash_attn.py
0 → 100755
View file @
afd0da21
from
typing
import
List
,
Optional
,
Tuple
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.flash_attn
import
(
cascade_attention
,
merge_attn_states
)
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
is_fa_version_supported
)
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
192
,
256
]
BLOCK_SIZES
=
[
16
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
39
,
16912
])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
()
def
test_merge_kernel
(
num_tokens
:
int
,
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
):
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
# Prepare inputs.
prefix_output
=
torch
.
randn
(
num_tokens
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
suffix_output
=
torch
.
randn
(
num_tokens
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
prefix_lse
=
torch
.
randn
(
num_query_heads
,
num_tokens
,
dtype
=
torch
.
float32
)
suffix_lse
=
torch
.
randn
(
num_query_heads
,
num_tokens
,
dtype
=
torch
.
float32
)
# Run the kernel.
output
=
torch
.
empty
(
num_tokens
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
)
# Reference implementation.
max_lse
=
torch
.
maximum
(
prefix_lse
,
suffix_lse
)
p_lse
=
torch
.
exp
(
prefix_lse
-
max_lse
)
s_lse
=
torch
.
exp
(
suffix_lse
-
max_lse
)
p_scale
=
p_lse
/
(
p_lse
+
s_lse
)
s_scale
=
s_lse
/
(
p_lse
+
s_lse
)
p_scale
=
p_scale
.
transpose
(
0
,
1
).
unsqueeze
(
2
)
s_scale
=
s_scale
.
transpose
(
0
,
1
).
unsqueeze
(
2
)
ref_output
=
p_scale
*
prefix_output
+
s_scale
*
suffix_output
ref_output
=
ref_output
.
to
(
dtype
)
# Compare the results.
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
)
CASES
=
[
# Case 1. A general case.
([(
129
,
871
),
(
18
,
280
),
(
37
,
988
),
(
1023
,
2304
),
(
1
,
257
)],
256
),
# Case 2. Flash-decoding case.
([(
1
,
1023
),
(
1
,
879
),
(
1
,
778
),
(
1
,
1777
)]
*
100
,
512
),
]
@
pytest
.
mark
.
parametrize
(
"seq_lens_and_common_prefix"
,
CASES
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
50
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
def
test_cascade
(
seq_lens_and_common_prefix
:
Tuple
[
List
[
Tuple
[
int
,
int
]],
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
fa_version
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
current_platform
.
seed_everything
(
0
)
window_size
=
(
-
1
,
-
1
)
scale
=
head_size
**-
0.5
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
key_cache
=
torch
.
randn
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
seq_lens
,
common_prefix_len
=
seq_lens_and_common_prefix
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
max_query_len
=
max
(
query_lens
)
max_kv_len
=
max
(
kv_lens
)
total_num_query_tokens
=
sum
(
query_lens
)
query
=
torch
.
randn
(
total_num_query_tokens
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
assert
common_prefix_len
>
0
assert
common_prefix_len
%
block_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
block_size
# Make sure the first `num_common_kv_blocks` blocks are the same.
block_tables
[:,
:
num_common_kv_blocks
]
=
\
block_tables
[
0
,
:
num_common_kv_blocks
]
# Run the regular attention.
ref_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
kv_lens_tensor
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
# Run cascade attention.
assert
all
(
common_prefix_len
<
kv_len
for
kv_len
in
kv_lens
)
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
total_num_query_tokens
],
dtype
=
torch
.
int32
)
prefix_kv_lens
=
torch
.
tensor
([
common_prefix_len
],
dtype
=
torch
.
int32
)
suffix_kv_lens
=
kv_lens_tensor
-
common_prefix_len
output
=
torch
.
empty_like
(
query
)
cascade_attention
(
output
=
output
,
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
cu_query_lens
=
cu_query_lens
,
max_query_len
=
max_query_len
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
max_kv_len
=
max_kv_len
,
softmax_scale
=
scale
,
alibi_slopes
=
None
,
sliding_window
=
window_size
,
logits_soft_cap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
block_table
=
block_tables
,
common_prefix_len
=
common_prefix_len
,
fa_version
=
fa_version
,
)
# Compare the results.
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
)
tests/kernels/test_cutlass.py
View file @
afd0da21
...
...
@@ -2,7 +2,7 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from
typing
import
Optional
,
Type
from
typing
import
Type
,
Optional
import
pytest
import
torch
...
...
@@ -10,6 +10,9 @@ import torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
.utils
import
baseline_scaled_mm
,
to_fp8
,
to_int8
MNK_FACTORS
=
[
(
1
,
256
,
128
),
...
...
@@ -37,20 +40,15 @@ CUDA_DEVICES = [
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE
=
(
-
1
,
-
1
)
PER_TOKEN_GROUP_SHAPE
=
(
1
,
-
1
)
PER_OUT_CH_GROUP_SHAPE
=
(
-
1
,
1
)
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
to_int8
(
tensor
:
torch
.
Tensor
):
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def
rand_int8
(
shape
:
tuple
,
device
:
str
=
"cuda"
):
return
to_int8
(
torch
.
rand
(
shape
,
device
=
device
)
*
255
-
128
)
...
...
@@ -66,14 +64,22 @@ def baseline_scaled_mm(a: torch.Tensor,
if
bias
is
not
None
:
output
=
output
+
bias
return
output
def
group_scale_helper
(
shape
,
group_shape
):
return
[
shape
[
i
]
if
s
<
0
else
s
for
i
,
s
in
enumerate
(
group_shape
)]
def
scale_shape
(
shape
,
group_shape
):
assert
len
(
shape
)
==
len
(
group_shape
)
group_shape
=
group_scale_helper
(
shape
,
group_shape
)
return
tuple
(
cdiv
(
shape
[
i
],
group_shape
[
i
])
for
i
in
range
(
len
(
group_shape
)))
def
cutlass_fp8_gemm_helper
(
m
:
int
,
n
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
a_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
...
...
@@ -82,13 +88,17 @@ def cutlass_fp8_gemm_helper(m: int,
a
=
to_fp8
(
torch
.
randn
((
m
,
k
),
device
=
device
))
b
=
to_fp8
(
torch
.
randn
((
n
,
k
),
device
=
device
).
t
())
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
a_scales_shape
=
scale_shape
(
a
.
shape
,
a_scale_group_shape
)
b_scales_shape
=
scale_shape
(
b
.
shape
,
b_scale_group_shape
)
scale_a
=
(
torch
.
randn
(
a_scales_shape
,
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
b_scales_shape
,
device
=
device
,
dtype
=
torch
.
float32
))
# make scales M-major for blockwise quant, doesn't affect 1D scales
scale_a
=
scale_a
.
t
().
contiguous
().
t
()
# make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b
=
scale_b
.
t
().
contiguous
().
t
()
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
))
if
use_bias
:
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
else
:
...
...
@@ -106,8 +116,8 @@ def cutlass_fp8_gemm_helper(m: int,
def
cutlass_int8_gemm_helper
(
m
:
int
,
n
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
a_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
...
...
@@ -116,13 +126,11 @@ def cutlass_int8_gemm_helper(m: int,
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
device
)
*
5
)
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
device
).
t
()
*
5
)
m_
a_scales
=
m
if
per_token_act_quant
else
1
n_
b_scales
=
n
if
per_out_channel_weight_quant
else
1
a_scales
_shape
=
scale_shape
(
a
.
shape
,
a_scale_group_shape
)
b_scales
_shape
=
scale_shape
(
b
.
shape
,
b_scale_group_shape
)
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
))
scale_a
=
(
torch
.
randn
(
a_scales_shape
,
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
(
b_scales_shape
,
device
=
device
,
dtype
=
torch
.
float32
))
if
use_bias
:
bias
=
torch
.
rand
((
n
,
),
device
=
device
,
dtype
=
out_dtype
)
*
10
...
...
@@ -139,85 +147,139 @@ def cutlass_int8_gemm_helper(m: int,
# @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
# per_out_ch: bool, use_bias: bool):
# cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
# def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
# b_scale_group_shape, use_bias: bool):
# cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
# use_bias)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape,b_scale_group_shape"
,
[((
1
,
128
),
(
128
,
128
))])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
reason
=
"FP8 blockwise is not supported on this GPU type."
)
def
test_cutlass_fp8_blockwise_scale_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
:
bool
):
if
k
%
b_scale_group_shape
[
0
]
!=
0
or
n
%
b_scale_group_shape
[
1
]
!=
0
:
return
if
m
%
a_scale_group_shape
[
0
]
!=
0
or
k
%
a_scale_group_shape
[
1
]
!=
0
:
return
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
float16
])
#torch.bfloat16,
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_output_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
Type
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
,
out_dtype
=
out_dtype
)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
# b_scale_group_shape,
# out_dtype: Type[torch.dtype],
# use_bias: bool):
# cutlass_fp8_gemm_helper(512,
# 512,
# 512,
#
per_act_token
,
#
per_out_ch
,
#
a_scale_group_shape
,
#
b_scale_group_shape
,
# use_bias,
# out_dtype=out_dtype)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
# [((1, 128), (128, 128))])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [False])
# @pytest.mark.skipif(not current_platform.has_device_capability(90),
# reason="FP8 blockwise is not supported on this GPU type.")
# def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
# b_scale_group_shape,
# out_dtype: Type[torch.dtype],
# use_bias: bool):
# cutlass_fp8_gemm_helper(512,
# 512,
# 512,
# a_scale_group_shape,
# b_scale_group_shape,
# use_bias,
# out_dtype=out_dtype)
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_devices(
per_act_token: bool, per_out_ch: bool
,
# def test_cutlass_fp8_gemm_devices(
a_scale_group_shape, b_scale_group_shape
,
# use_bias: bool, device: str):
# cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
# torch.bfloat16, device)
# cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
# b_scale_group_shape, use_bias, torch.bfloat16,
# device)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# use_bias: bool, device: str):
# cutlass_int8_gemm_helper(512,
# 512,
# 512,
# per_act_token,
# per_out_ch,
# use_bias,
# out_dtype=torch.bfloat16,
# device=device)
# def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias
:
bool
,
device
:
str
):
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
,
out_dtype
=
torch
.
bfloat16
,
device
=
device
)
# For the following two tests:
...
...
@@ -225,28 +287,32 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_m_sweep(
per_act_token: bool, per_out_ch: bool
,
# def test_cutlass_fp8_gemm_m_sweep(
a_scale_group_shape, b_scale_group_shape
,
# use_bias: bool):
# for nk in range(32, 128, 32):
# for m in range(1, 128):
# cutlass_fp8_gemm_helper(m, nk, nk,
per_act_token, per_out_ch
,
# use_bias)
# cutlass_fp8_gemm_helper(m, nk, nk,
a_scale_group_shape
,
#
b_scale_group_shape,
use_bias)
# @pytest.mark.parametrize("per_act_token", [True, False])
# @pytest.mark.parametrize("per_out_ch", [True, False])
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("use_bias", [True, False])
# def test_cutlass_int8_gemm_m_sweep(
per_act_token: bool, per_out_ch: bool
,
#
use_bias: bool):
#
for nk in range(32, 128, 32):
#
for m in range(1, 128):
#
cutlass_int8_gemm_helper(m, nk, nk,
per_act_token, per_out_ch
,
#
use_bias)
# def test_cutlass_int8_gemm_m_sweep(
a_scale_group_shape, b_scale_group_shape
,
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
)
# @pytest.mark.parametrize("m", [32, 64, 128])
...
...
@@ -304,38 +370,39 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# @pytest.mark.parametrize("n", [16, 32, 64])
# @pytest.mark.parametrize("k", [64, 128, 256])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.skip
# def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
# out_dtype: torch.dtype):
# # Currently, the test is failing because folding azp into
# # 16-bit bias loses too much precision
# scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
# scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
# aq_i8 = rand_int8((m, k))
# bq_i8 = rand_int8((n, k)).t()
# aq_i32 = aq_i8.to(dtype=torch.int32)
# bq_i32 = bq_i8.to(dtype=torch.int32)
# aq_f32 = aq_i8.to(dtype=torch.float32)
# bq_f32 = bq_i8.to(dtype=torch.float32)
# b_dq = scale_b * bq_f32
# azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
# azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
# azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
# a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
# torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
# baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("azp_per_token", [True, False])
# def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
# use_bias: bool, azp_per_token: bool):
# m_azp = m if azp_per_token else 1
# scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
# scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
# aq_i8 = rand_int8((m, k))
# aq_i32 = aq_i8.to(dtype=torch.int32)
# aq_f32 = aq_i8.to(dtype=torch.float32)
# bq_i8 = rand_int8((n, k)).t()
# bq_i32 = bq_i8.to(dtype=torch.int32)
# bq_f32 = bq_i8.to(dtype=torch.float32)
# b_dq = scale_b * bq_f32
# azp_a = torch.rand(
# (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
# azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
# azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
# a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
# torch.testing.assert_close(a_dq,
# scale_a * aq_f32 - azp_a,
# rtol=1e-4,
# atol=1e-3)
# if use_bias:
# bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
# else:
# bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
# J = torch.ones((1, k), device="cuda", dtype=torch.float32)
# azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
# assert azp_bias.shape == (1, n)
# assert azp_bias[0, :].shape == (n, )
# baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
# (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
...
...
tests/kernels/test_
semi_structured
.py
→
tests/kernels/test_
cutlass_2of4_sparse
.py
View file @
afd0da21
...
...
@@ -2,16 +2,19 @@
Run `pytest tests/kernels/test_cutlass_2of4_sparse.py`.
"""
from
typing
import
Optional
,
Tuple
,
Type
from
typing
import
Tuple
,
Type
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
sparse_cutlass_supported
)
from
vllm.platforms
import
current_platform
from
.utils
import
baseline_scaled_mm
,
to_fp8
,
to_int8
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
...
...
@@ -20,20 +23,6 @@ capability = current_platform.get_device_capability()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
to_int8
(
tensor
:
torch
.
Tensor
):
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def
rand_int8
(
shape
:
tuple
,
device
:
str
=
"cuda"
):
return
to_int8
(
torch
.
rand
(
shape
,
device
=
device
)
*
255
-
128
)
def
to_bf16
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
tensor
.
to
(
dtype
=
torch
.
bfloat16
)
...
...
@@ -90,22 +79,8 @@ def make_rand_sparse_tensors(
return
b_compressed
,
e
,
a
,
b
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
(
scale_a
*
(
scale_b
*
(
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse
FP8
is not
yet
supported on this GPU type."
)
reason
=
"Sparse
CUTLASS
is not supported on this GPU type."
)
# Test working with a subset of A and B for sparse matmul
def
test_cutlass_sparse_subset
():
...
...
@@ -132,3 +107,108 @@ def test_cutlass_sparse_subset():
out_dtype
=
torch
.
bfloat16
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
# Problem sizes shared by the sparse GEMM tests in this file.
# NOTE: despite the name, the parametrize decorators below unpack each tuple
# as (m, k, n), not (m, n, k).
MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 512),
    (16, 256, 512),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 512),
    (64, 16384, 1024),
    (100, 8192, 512),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]
# Test the 2of4 sparse GEMM for 16-bit dtypes against a dense F.linear reference
@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
    """Check the sparse CUTLASS GEMM against ``F.linear`` for 16-bit dtypes.

    Both scales are all-ones (1, 1) float32 tensors, so the kernel output is
    compared directly with the dense product of ``a`` and ``b``.
    Currently unconditionally skipped (see the ``skip`` marker above).
    """
    # Random operands: compressed B, its companion tensor `e`, dense A and B.
    b_compressed, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)

    scale_kwargs = dict(device="cuda", dtype=torch.float32)
    scale_a = torch.ones((1, 1), **scale_kwargs)
    scale_b = torch.ones((1, 1), **scale_kwargs)

    actual = ops.cutlass_scaled_sparse_mm(a,
                                          b_compressed,
                                          e,
                                          scale_a,
                                          scale_b,
                                          out_dtype=dtype)
    expected = F.linear(a, b.T)

    torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):
    """Check the sparse CUTLASS GEMM against ``baseline_scaled_mm`` for FP8.

    Operands are generated as ``torch.float8_e4m3fn``; both scales are random
    (1, 1) float32 tensors and the output dtype is bfloat16. The comparison
    uses deliberately loose tolerances (rtol=1e0, atol=2e0).
    """
    result_dtype = torch.bfloat16

    # Random operands: compressed B, its companion tensor `e`, dense A and B.
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m,
                                                     n, k)

    # Keep the draw order (scale_a first, then scale_b) for reproducibility.
    scale_kwargs = dict(device="cuda", dtype=torch.float32)
    scale_a = torch.randn((1, 1), **scale_kwargs)
    scale_b = torch.randn((1, 1), **scale_kwargs)

    actual = ops.cutlass_scaled_sparse_mm(a,
                                          b_compressed,
                                          e,
                                          scale_a,
                                          scale_b,
                                          out_dtype=result_dtype)
    expected = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=result_dtype)

    torch.testing.assert_close(actual, expected, rtol=1e0, atol=2e0)
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                                  per_out_ch: bool, use_bias: bool):
    """Check the sparse CUTLASS GEMM against ``baseline_scaled_mm`` for int8.

    Operands are generated as ``torch.int8``; both scales are random (1, 1)
    float32 tensors and the output dtype is bfloat16. The comparison uses
    deliberately loose tolerances (rtol=1e0, atol=2e0).

    NOTE(review): ``per_act_token``, ``per_out_ch`` and ``use_bias`` are
    parametrized but never read in the body, so every (m, k, n) case runs
    eight times with the same scale shapes and no bias. Either wire them in
    or drop the parametrization; confirm what the kernel supports first.
    """
    result_dtype = torch.bfloat16

    # Random operands: compressed B, its companion tensor `e`, dense A and B.
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)

    # Keep the draw order (scale_a first, then scale_b) for reproducibility.
    scale_kwargs = dict(device="cuda", dtype=torch.float32)
    scale_a = torch.randn((1, 1), **scale_kwargs)
    scale_b = torch.randn((1, 1), **scale_kwargs)

    actual = ops.cutlass_scaled_sparse_mm(a,
                                          b_compressed,
                                          e,
                                          scale_a,
                                          scale_b,
                                          out_dtype=result_dtype)
    expected = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=result_dtype)

    torch.testing.assert_close(actual, expected, rtol=1e0, atol=2e0)
tests/kernels/test_encoder_decoder_attn.py
View file @
afd0da21
...
...
@@ -13,8 +13,7 @@ import pytest
import
torch
from
tests.kernels.utils
import
*
from
vllm.attention
import
(
Attention
,
AttentionBackend
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
...
...
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
max_dec_seq_len
:
int
max_enc_seq_len
:
int
num_blocks
:
int
attn_type
:
AttentionType
class
TestResources
(
NamedTuple
):
...
...
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
'''
scale
:
float
attn_backend
:
AttentionBackend
attn
:
Attention
kv_cache
:
torch
.
Tensor
...
...
@@ -129,26 +128,33 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
'''
scale
=
float
(
1.0
/
(
test_pt
.
head_size
**
0.5
))
attn_backend
=
make_backend
(
test_pt
.
backend_name
)
attn
=
Attention
(
test_pt
.
num_heads
,
test_pt
.
head_size
,
scale
=
scale
,
prefix
=
f
"
{
test_pt
.
attn_type
}
"
,
attn_type
=
test_pt
.
attn_type
,
)
if
test_pt
.
num_blocks
is
None
or
test_pt
.
num_heads
is
None
:
# Caller does not require a KV cache
return
TestResources
(
scale
,
attn_backend
,
attn
,
scale
,
attn
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
CUDA_DEVICE
))
# Construct KV cache
kv_cache
=
make_kv_cache
(
test_pt
.
num_blocks
,
test_pt
.
num_heads
,
test_pt
.
head_size
,
test_pt
.
block_size
,
device
=
CUDA_DEVICE
,
backend
=
test_pt
.
backend_name
)
return
TestResources
(
scale
,
attn_backend
,
attn
,
kv_cache
)
if
test_pt
.
attn_type
in
(
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_DECODER
):
kv_cache
=
make_kv_cache
(
test_pt
.
num_blocks
,
test_pt
.
num_heads
,
test_pt
.
head_size
,
test_pt
.
block_size
,
device
=
CUDA_DEVICE
,
backend
=
test_pt
.
backend_name
)
else
:
kv_cache
=
torch
.
tensor
([])
attn
.
kv_cache
=
[
kv_cache
]
return
TestResources
(
scale
,
attn
,
kv_cache
)
def
_encoder_attn_setup
(
...
...
@@ -193,6 +199,7 @@ def _encoder_attn_setup(
_
,
max_q_seq_len
,
_
,
_
,
)
=
test_pt
scale
=
test_rsrcs
.
scale
...
...
@@ -301,6 +308,7 @@ def _decoder_attn_setup(
max_q_seq_len
,
_
,
_
,
_
,
)
=
test_pt
scale
=
test_rsrcs
.
scale
...
...
@@ -488,6 +496,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
max_decoder_seq_len
,
max_encoder_seq_len
,
_
,
_
,
)
=
test_pt
scale
=
test_rsrcs
.
scale
...
...
@@ -622,7 +631,6 @@ def _run_encoder_attention_test(
& attn_metadata
'''
assert
attn_metadata
.
num_decode_tokens
==
0
attn_type
=
AttentionType
.
ENCODER
packed_qkv
=
encoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
with
set_forward_context
(
attn_metadata
,
vllm_config
):
...
...
@@ -635,14 +643,11 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
,
attn_type
=
attn_type
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
)
def
_run_decoder_self_attention_test
(
...
...
@@ -675,7 +680,6 @@ def _run_decoder_self_attention_test(
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
'''
attn_type
=
AttentionType
.
DECODER
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
...
...
@@ -690,12 +694,8 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
kv_cache
,
attn_metadata
,
attn_type
=
attn_type
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
kv_cache
,
attn_metadata
)
def
_run_encoder_decoder_cross_attention_test
(
...
...
@@ -742,7 +742,6 @@ def _run_encoder_decoder_cross_attention_test(
'''
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
attn_type
=
AttentionType
.
ENCODER_DECODER
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
if
cross_test_params
is
None
:
...
...
@@ -762,12 +761,8 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
attn_metadata
,
attn_type
=
attn_type
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
attn_metadata
)
@
pytest
.
fixture
(
autouse
=
True
)
...
...
@@ -839,7 +834,7 @@ def test_encoder_only(
# is not part of this test
test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
)
max_enc_seq_len
,
4096
,
AttentionType
.
ENCODER
)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
...
...
@@ -855,7 +850,7 @@ def test_encoder_only(
# Shared prefill metadata structure
prephase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
test_rsrcs
.
attn_backend
,
attn_backend
,
True
,
None
,
decoder_test_params
=
None
,
...
...
@@ -961,20 +956,29 @@ def test_e2e_enc_dec_attn(
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
)
enc_test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
,
AttentionType
.
ENCODER
)
enc_dec_test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
,
AttentionType
.
ENCODER_DECODER
)
dec_test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
,
AttentionType
.
DECODER
)
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
test_rsrcs
=
_make_test_resources
(
test_pt
)
enc_test_rsrcs
=
_make_test_resources
(
enc_test_pt
)
enc_dec_test_rsrcs
=
_make_test_resources
(
enc_dec_test_pt
)
dec_test_rsrcs
=
_make_test_resources
(
dec_test_pt
)
# Construct encoder attention test params (only used
# during prefill)
enc_test_params
=
_encoder_attn_setup
(
test_pt
,
test_rsrcs
)
enc_test_params
=
_encoder_attn_setup
(
enc_
test_pt
,
enc_
test_rsrcs
)
# Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention
...
...
@@ -987,7 +991,7 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params
,
decphase_dec_test_params
,
cross_block_base_addr
,
)
=
_decoder_attn_setup
(
test_pt
,
test_rsrcs
)
)
=
_decoder_attn_setup
(
dec_
test_pt
,
dec_
test_rsrcs
)
# Construct encoder/decoder cross-attention prefill-phase
# & decode-phase test params, including key/value tensors,
...
...
@@ -1000,14 +1004,14 @@ def test_e2e_enc_dec_attn(
dec_qkv
,
enc_test_params
,
prephase_dec_test_params
,
test_pt
,
test_rsrcs
,
enc_dec_
test_pt
,
enc_dec_
test_rsrcs
,
block_base_addr
=
cross_block_base_addr
)
# Shared prefill metadata structure
assert
prephase_dec_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
prephase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
test_rsrcs
.
attn_backend
,
attn_backend
,
True
,
prephase_dec_test_params
.
packed_qkvo
.
packed_qkv
.
q_seq_lens
,
decoder_test_params
=
prephase_dec_test_params
,
...
...
@@ -1017,10 +1021,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder attention
enc_pckd_act_out
=
_run_encoder_attention_test
(
test_rsrcs
.
attn
,
enc_pckd_act_out
=
_run_encoder_attention_test
(
enc_
test_rsrcs
.
attn
,
enc_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
enc_
test_pt
,
vllm_config
=
vllm_config
)
# - Is encoder attention result correct?
...
...
@@ -1030,10 +1034,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
test_rsrcs
,
dec_
test_rsrcs
,
prephase_dec_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
dec_
test_pt
,
vllm_config
=
vllm_config
)
# - Is prefill decoder self-attention correct?
...
...
@@ -1044,11 +1048,11 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
test_rsrcs
,
enc_dec_
test_rsrcs
,
prephase_dec_test_params
,
prephase_cross_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
enc_dec_
test_pt
,
vllm_config
=
vllm_config
)
# - Is prefill encoder/decoder cross-attention correct?
...
...
@@ -1059,7 +1063,7 @@ def test_e2e_enc_dec_attn(
# DECODE: build decode-phase attention metadata
decphase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
test_rsrcs
.
attn_backend
,
attn_backend
,
False
,
dec_qkv
.
q_seq_lens
,
decoder_test_params
=
decphase_dec_test_params
,
...
...
@@ -1070,10 +1074,10 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
test_rsrcs
,
dec_
test_rsrcs
,
decphase_dec_test_params
,
decphase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
dec_
test_pt
,
vllm_config
=
vllm_config
)
# - Is decode-phase decoder self-attention correct?
...
...
@@ -1084,11 +1088,11 @@ def test_e2e_enc_dec_attn(
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
test_rsrcs
,
enc_dec_
test_rsrcs
,
decphase_dec_test_params
,
None
,
decphase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
enc_dec_
test_pt
,
vllm_config
=
vllm_config
)
# - Is decode-phase encoder/decoder cross-attention correct?
...
...
tests/kernels/test_flash_attn.py
View file @
afd0da21
...
...
@@ -5,11 +5,14 @@ import torch
from
vllm.platforms
import
current_platform
if
current_platform
():
import
flash_attn
else
:
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
is_fa_version_supported
)
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
...
...
@@ -84,6 +87,7 @@ if not current_platform():
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
use_out
:
bool
,
...
...
@@ -95,8 +99,13 @@ if not current_platform():
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
sliding_window
:
Optional
[
int
],
fa_version
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
...
...
@@ -135,6 +144,7 @@ if not current_platform():
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
fa_version
=
fa_version
,
)
output
=
output
if
not
use_out
else
out
output
=
output
.
squeeze
(
1
)
...
...
@@ -150,9 +160,8 @@ if not current_platform():
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
...
...
@@ -164,6 +173,7 @@ if not current_platform():
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
use_out
:
bool
,
...
...
@@ -175,8 +185,12 @@ def test_varlen_with_paged_kv(
block_size
:
int
,
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
fa_version
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
...
...
@@ -206,6 +220,7 @@ def test_varlen_with_paged_kv(
cu_kv_lens
=
torch
.
tensor
([
0
]
+
kv_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
kv_lens
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
...
...
@@ -230,6 +245,7 @@ def test_varlen_with_paged_kv(
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
# fa_version=fa_version,
)
else
:
output
=
flash_attn_varlen_func
(
...
...
@@ -238,7 +254,7 @@ def test_varlen_with_paged_kv(
v
=
value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_
kv_lens
,
seqused_k
=
kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
...
...
@@ -246,7 +262,9 @@ def test_varlen_with_paged_kv(
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
fa_version
=
fa_version
,
)
output
=
output
if
not
use_out
else
out
ref_output
=
ref_paged_attn
(
...
...
tests/kernels/test_mha_attn.py
0 → 100644
View file @
afd0da21
"""
Test:
* Tests for MultiHeadAttention layer
"""
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.selector
import
_Backend
,
_cached_get_attn_backend
from
vllm.platforms
import
current_platform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.rocm
import
RocmPlatform
@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the attention-backend lru cache before every test.

    Runs automatically (``autouse=True``) so that a backend selected by one
    test case is never reused by the next one.
    """
    _cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
    """
    Test the attention backend picked by MultiHeadAttention per platform.

    ``vllm.attention.selector.current_platform`` is patched with a CPU, ROCm
    or CUDA platform object and the layer's ``attn_backend`` is checked:
    TORCH_SDPA for "cpu" and "hip", XFORMERS for "cuda" (head sizes 64
    and 72).

    The process-wide default dtype is switched to float16 for the duration
    of the test and restored afterwards so the setting does not leak into
    other tests.
    """
    saved_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    try:
        if device == "cpu":
            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                attn = MultiHeadAttention(16, 64, scale=1)
                assert attn.attn_backend == _Backend.TORCH_SDPA
        elif device == "hip":
            with patch("vllm.attention.selector.current_platform",
                       RocmPlatform()):
                attn = MultiHeadAttention(16, 64, scale=1)
                assert attn.attn_backend == _Backend.TORCH_SDPA
        else:
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                attn = MultiHeadAttention(16, 64, scale=1)
                assert attn.attn_backend == _Backend.XFORMERS

            # NOTE(review): the head_size=72 probe is kept inside the CUDA
            # branch; confirm this nesting matches the intended layout.
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                attn = MultiHeadAttention(16, 72, scale=1)
                assert attn.attn_backend == _Backend.XFORMERS
    finally:
        # Undo the global dtype change even when an assertion fails.
        torch.set_default_dtype(saved_dtype)
def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Reference scaled dot-product attention without any mask.

    Args:
        query, key, value: 4-D tensors; dims 1 and 2 are swapped before the
            matmuls and swapped back on the result (the file's convention is
            [batch_size, seq_len, num_heads, head_size]).
        scale: multiplier applied to the query-key scores before softmax.

    Returns:
        The attention output, laid out like ``query``.
    """
    # Move the head dimension ahead of the sequence dimension.
    q_heads = query.transpose(1, 2)
    k_heads = key.transpose(1, 2)
    v_heads = value.transpose(1, 2)

    scores = scale * torch.matmul(q_heads, k_heads.transpose(2, 3))
    # Softmax over the key axis, then match the dtype of the values.
    probs = torch.softmax(scores, dim=-1).to(v_heads.dtype)

    # Swap heads and sequence back to the input layout.
    return torch.matmul(probs, v_heads).transpose(1, 2)
# Parameter grids consumed by the parametrize decorators of
# test_mha_attn_forward.
BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
# float32 is only added to the grid when not running on ROCm.
DTYPES = [
    torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
CUDA_DEVICES = ["cuda"]
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    """
    Compare ``MultiHeadAttention`` with the native ``ref_attention``.

    Random q/k/v of shape [batch_size, seq_len, heads * head_size] are fed
    to the layer; the same tensors, split per head (and with k/v heads
    repeated when there are fewer kv heads than query heads), go through
    ``ref_attention``. The two results must agree under the default
    tolerances of ``torch.testing.assert_close``.
    """
    current_platform.seed_everything(0)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    # Drawn in q, k, v order so the seeded values stay reproducible.
    q = torch.randn(batch_size, seq_len, num_heads * head_size)
    k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
    v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)

    scale = 1.0 / head_size**0.5
    attn = MultiHeadAttention(num_heads,
                              head_size,
                              scale=scale,
                              num_kv_heads=num_kv_heads)
    output = attn(q, k, v)

    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads

    def split_heads(x: torch.Tensor, heads: int) -> torch.Tensor:
        # [batch, seq, heads * head_size] -> [batch, seq, heads, head_size]
        return x.reshape(batch_size, seq_len, heads, head_size)

    q = split_heads(q, num_heads)
    k = split_heads(k, num_kv_heads)
    v = split_heads(v, num_kv_heads)
    if num_queries_per_kv > 1:
        # Repeat every kv head so each query head has its own copy.
        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

    ref_output = ref_attention(
        q,
        k,
        v,
        scale=scale,
    ).reshape(batch_size, seq_len, num_heads * head_size)
    torch.testing.assert_close(output, ref_output)
tests/kernels/test_moe.py
View file @
afd0da21
...
...
@@ -14,8 +14,12 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
)
from
vllm.model_executor.layers.fused_moe.moe_torch_iterative
import
(
fused_moe
as
iterative_moe
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
marlin_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
quantize_weights
)
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
...
...
@@ -46,8 +50,102 @@ def test_fused_moe(
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        dtype: torch.dtype, group_size: int, has_zp: bool,
                        weight_bits: int):
    """
    Compare ``fused_moe`` on 4/8-bit quantized weights with ``torch_moe``.

    Every expert matrix of ``w1`` and ``w2`` is quantized with
    ``quantize_weights``. The first returned tensor (``weight``) replaces
    the reference weights handed to ``torch_moe``; the packed uint8 weights,
    scales and (if ``has_zp``) zero points go to ``fused_moe``. Both outputs
    must match within ``atol=2e-2``.

    Raises:
        ValueError: if ``weight_bits`` is neither 4 nor 8.
    """
    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit values share one uint8.
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
    else:
        raise ValueError(f"unsupported weight_bits: {weight_bits}")

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w2_qweight = torch.empty((e, k, n // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size),
                            device="cuda",
                            dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size),
                            device="cuda",
                            dtype=dtype)
    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
                            device="cuda",
                            dtype=torch.uint8)
    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
                            device="cuda",
                            dtype=torch.uint8)

    # All experts of w1 first, then all experts of w2.
    weight_sets = (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    )
    for w, w_ref, w_qweight, w_scales, w_qzeros in weight_sets:
        for expert_id in range(e):
            # quantize_weights gets the transposed expert matrix; its
            # results are transposed back below.
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False)
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)
            if weight_bits == 4:
                # Pack value pairs into one byte: odd index in the high
                # nibble, even index in the low nibble.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    triton_output = fused_moe(a,
                              w1_qweight,
                              w2_qweight,
                              score,
                              topk,
                              renormalize=False,
                              use_int4_w4a16=weight_bits == 4,
                              use_int8_w8a16=weight_bits == 8,
                              w1_scale=w1_scales,
                              w2_scale=w2_scales,
                              w1_zp=w1_qzeros if has_zp else None,
                              w2_zp=w2_qzeros if has_zp else None,
                              block_shape=[0, group_size])
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
torch
.
inference_mode
()
...
...
tests/kernels/test_prefix_prefill.py
View file @
afd0da21
...
...
@@ -140,6 +140,7 @@ def test_contexted_kv_attention(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache
=
v_cache
.
view
(
-
1
,
block_size
,
num_kv_heads
,
head_size
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
...
...
@@ -155,6 +156,8 @@ def test_contexted_kv_attention(
b_seq_len
,
b_ctx_len
,
max_input_len
,
k_scale
,
v_scale
,
sliding_window
=
sliding_window
)
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
...
...
@@ -170,6 +173,8 @@ def test_contexted_kv_attention(
b_seq_len
,
b_ctx_len
,
max_input_len
,
k_scale
,
v_scale
,
sliding_window
=
sliding_window
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
...
...
@@ -369,6 +374,7 @@ def test_contexted_kv_attention_alibi(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache
=
v_cache
.
view
(
-
1
,
block_size
,
num_kv_heads
,
head_size
).
permute
(
0
,
2
,
3
,
1
).
contiguous
()
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
...
...
@@ -384,6 +390,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len
,
b_ctx_len
,
max_input_len
,
k_scale
,
v_scale
,
alibi_slopes
=
alibi_slopes
)
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
...
...
@@ -399,6 +407,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len
,
b_ctx_len
,
max_input_len
,
k_scale
,
v_scale
,
alibi_slopes
=
alibi_slopes
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
time
()
...
...
tests/kernels/test_triton_decode_attention.py
0 → 100644
View file @
afd0da21
import
pytest
import
torch
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
def cdiv(a, b):
    """Return ``(a + b - 1) // b``, i.e. ``a / b`` rounded up for
    non-negative integer ``a`` and positive integer ``b``."""
    return (a + b - 1) // b
@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
    """
    Check that ``decode_attention_fwd`` gives the same output whether the
    KV cache is addressed per token or per page.

    The kernel runs twice on the same random q/k/v data: first with flat
    buffers and a per-token index table, then with the buffers viewed as
    pages of ``PAGE_SIZE`` tokens and a per-page index table. Both outputs
    must be ``torch.allclose``.
    """
    assert CACHE_SIZE % PAGE_SIZE == 0
    dtype = torch.bfloat16
    seq_len = L  # This represents the number of tokens already in the sequence
    sm_scale = 1.0 / (D_QK**0.5)
    num_kv_splits = 8

    num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
    req_to_page = torch.randint(0,
                                CACHE_SIZE // PAGE_SIZE,
                                (B, num_pages_per_batch, 1),
                                device="cuda")
    # Token slot = page index * PAGE_SIZE + offset inside the page.
    # Broadcasting (B, pages, 1) against (1, 1, PAGE_SIZE) yields one slot
    # per token of every page.
    page_offsets = torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
    req_to_token = req_to_page * PAGE_SIZE + page_offsets
    # Flatten the pages and drop the slots past seq_len in the last page.
    req_to_token = req_to_token.view(B, -1)[:, :seq_len].contiguous()

    # q represents the new token being generated, one per batch
    q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")

    # k_buffer and v_buffer represent all previous tokens.
    # For the first call they are flat, one row per token (page size 1).
    k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
    v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")

    # o has the same leading dimensions as q, with D_V as the last one.
    o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

    b_seq_len = torch.full((B, ), seq_len, device="cuda")

    attn_logits = torch.empty(
        (B, H_Q, num_kv_splits, D_V + 1),
        dtype=torch.float32,
        device="cuda",
    )

    # Call the original implementation.
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o,
        req_to_token,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
    )

    # Page size can be larger than 1: view the same storage as pages.
    k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
    v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)

    o1 = torch.zeros_like(o)

    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o1,
        req_to_page,
        b_seq_len,
        attn_logits,
        num_kv_splits,
        sm_scale,
        PAGE_SIZE,
    )

    assert torch.allclose(o, o1)
tests/kernels/test_triton_scaled_mm.py
View file @
afd0da21
...
...
@@ -39,6 +39,23 @@ def get_8bit_types():
return
types
# This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
                                      max_tokens, num_logprobs):
    """
    Load ``model_path`` in bfloat16 through the ``vllm_runner`` fixture and
    run greedy generation with logprobs on ``example_prompts``.

    The result is not inspected; the test passes when loading and
    generation complete without raising.
    """
    dtype = "bfloat16"

    with vllm_runner(model_path, dtype=dtype) as vllm_model:
        vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
                                            num_logprobs)
@
pytest
.
mark
.
parametrize
(
"M"
,
[
1
,
33
,
64
,
512
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
256
,
971
,
20486
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
496
,
1024
])
...
...
tests/kernels/untest_flashinfer.py
View file @
afd0da21
...
...
@@ -133,17 +133,19 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores
=
(
(
num_query_heads
//
num_kv_heads
)
>
4
)
)
wrapper
.
begin_forward
(
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
"NONE"
,
data_type
=
dtype
)
output
=
wrapper
.
forward
(
query
,
key_value_cache
,
logits_soft_cap
=
soft_cap
)
wrapper
.
plan
(
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
"NONE"
,
q_data_type
=
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
)
output
=
wrapper
.
run
(
query
,
key_value_cache
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
...
...
@@ -228,7 +230,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
)
wrapper
.
begin_forward
(
wrapper
.
plan
(
qo_indptr
,
kv_indptr
,
kv_indices
,
...
...
@@ -237,12 +239,14 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
num_kv_heads
,
head_size
,
block_size
,
q_data_type
=
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
,
)
output
=
wrapper
.
forward
(
output
=
wrapper
.
run
(
query
,
key_value_cache
,
logits_soft_cap
=
soft_cap
,
)
ref_output
=
ref_paged_attn
(
query
=
query
,
...
...
@@ -253,7 +257,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1
e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
5
e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
...
@@ -332,7 +336,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
)
wrapper
.
begin_forward
(
wrapper
.
plan
(
qo_indptr
,
kv_indptr
,
kv_indices
,
...
...
@@ -341,13 +345,12 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
num_kv_heads
,
head_size
,
block_size
,
q_data_type
=
dtype
,
kv_data_type
=
kv_cache_dtype
,
logits_soft_cap
=
soft_cap
,
)
output
=
wrapper
.
forward
(
query
,
kv_cache_fp8
,
logits_soft_cap
=
soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
output
=
wrapper
.
run
(
query
,
kv_cache_fp8
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
.
squeeze
(
1
),
...
...
@@ -360,7 +363,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
del
query
del
block_tables
# verify prefill fp8
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1
e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
5
e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
...
@@ -439,21 +442,18 @@ def test_flashinfer_decode_with_paged_fp8_kv(
wrapper
=
flashinfer
.
\
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
use_tensor_cores
=
use_tensor_cores
)
wrapper
.
begin_forward
(
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
"NONE"
,
data_type
=
dtype
,
q_data_type
=
dtype
)
output
=
wrapper
.
forward
(
query
,
kv_cache_fp8
,
logits_soft_cap
=
soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
wrapper
.
plan
(
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
"NONE"
,
q_data_type
=
dtype
,
kv_data_type
=
kv_cache_dtype
,
logits_soft_cap
=
soft_cap
)
output
=
wrapper
.
run
(
query
,
kv_cache_fp8
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
key_cache
=
key_value_cache
[:,
0
,
:,
:,
:].
squeeze
(
1
)
value_cache
=
key_value_cache
[:,
1
,
:,
:,
:].
squeeze
(
1
)
...
...
tests/kernels/utils.py
View file @
afd0da21
...
...
@@ -5,7 +5,7 @@ import random
import
unittest
from
numbers
import
Number
from
typing
import
(
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Sequence
,
Tuple
,
Union
)
Type
,
Union
)
import
pytest
import
torch
...
...
@@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType
from
vllm.attention
import
AttentionBackend
,
AttentionMetadata
,
AttentionType
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.platforms.interface
import
_Backend
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
,
STR_XFORMERS_ATTN_VAL
,
make_tensor_with_pad
)
...
...
@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(
def
make_test_metadata
(
attn_backend
:
Attention
Backend
,
attn_backend
:
_
Backend
,
is_prompt
:
bool
,
seq_lens
:
Optional
[
List
[
int
]],
decoder_test_params
:
Optional
[
PhaseTestParameters
],
...
...
@@ -815,7 +816,7 @@ def make_test_metadata(
Arguments:
* attn_backend: Backend for sourcing attention kernels
* attn_backend
_name
: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params;
...
...
@@ -882,6 +883,8 @@ def make_test_metadata(
# (kv_mmap)
cross_kv_mmap
=
cross_test_params
.
kv_mmap
attn_backend_obj
=
make_backend
(
attn_backend
.
name
)
if
is_prompt
:
# Prefill-phase scenario
...
...
@@ -902,11 +905,11 @@ def make_test_metadata(
context_lens
,
encoder_seq_lens
,
device
=
device
)
return
attn_backend
.
make_metadata
(
return
attn_backend_obj
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
(
None
if
kv_mmap
is
None
else
kv_mmap
.
slot_mapping
),
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
True
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
...
...
@@ -952,10 +955,11 @@ def make_test_metadata(
encoder_seq_lens
,
device
=
device
)
return
attn_backend
.
make_metadata
(
return
attn_backend
_obj
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
kv_mmap
.
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
True
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
...
...
@@ -1096,3 +1100,56 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
kwargs
,
test_utils
=
test_utils
,
raise_exception
=
raise_exception
)
if
cond
else
{}
# For testing quantized linear kernels
def
to_fp8
(
tensor
:
torch
.
Tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def to_int8(tensor: torch.Tensor):
    """Clamp ``tensor`` to [-128, 127], round it with ``torch.round`` and
    cast it to ``torch.int8``."""
    clamped = tensor.clamp(min=-128, max=127)
    return torch.round(clamped).to(dtype=torch.int8)
def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Reference scaled matrix multiply.

    Computes ``(scale_a * a) @ (scale_b * b)`` in float32, casts the product
    to ``out_dtype`` and then adds ``bias`` if one is given.

    Args:
        a: left operand; converted to float32 before scaling.
        b: right operand; converted to float32 before scaling.
        scale_a: scales for ``a``. Along each dimension its extent must be
            1, equal to that of ``a``, or a divisor of it (group scaling).
        scale_b: scales for ``b``, with the same rules relative to ``b``.
        out_dtype: dtype the matmul result is cast to.
        bias: optional tensor added to the already-cast output.

    Returns:
        The scaled product, plus ``bias`` when provided.
    """

    # We treat N-dimensional group scaling as extended numpy-style
    # broadcasting. Numpy simply stretches dimensions with an extent of 1 to
    # match the target shape by repeating the data along that dimension
    # (broadcasting). We extend these semantics: if the extent of a dimension
    # in the source shape is not 1 and does not match the target shape, we
    # repeat each element along that dimension
    # target_shape[dim] // src_shape[dim] times.
    # Example, if we have:
    #       a = [[1, 2],    and target_shape = (2, 4)
    #            [3, 4]]
    # then we would expand a to:
    #       a = [[1, 1, 2, 2],
    #            [3, 3, 4, 4]]
    # NOTE: this function does not explicitly broadcast dimensions with an
    # extent of 1, since this can be done implicitly by pytorch.
    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            extent = t.shape[i]
            if extent != s and extent != 1:
                assert s % extent == 0
                repeats = s // extent
                # Insert a new axis after dim i, stretch it to ``repeats``
                # and merge it back into dim i.
                stretched = t.unsqueeze(i + 1).expand(*t.shape[:i + 1],
                                                      repeats,
                                                      *t.shape[i + 1:])
                t = stretched.flatten(i, i + 1)
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    scaled_a = scale_a * a.to(dtype=torch.float32)
    scaled_b = scale_b * b.to(dtype=torch.float32)
    output = torch.mm(scaled_a, scaled_b).to(out_dtype)

    if bias is not None:
        output = output + bias

    return output
Prev
1
…
17
18
19
20
21
22
23
24
25
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment