Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7e63ef82
Commit
7e63ef82
authored
Jan 21, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0' into v0.14.0-dev
parents
8cbcac5d
b17039bc
Changes
681
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
531 additions
and
106 deletions
+531
-106
tests/kernels/attention/test_cpu_attn.py
tests/kernels/attention/test_cpu_attn.py
+2
-1
tests/kernels/attention/test_flash_attn.py
tests/kernels/attention/test_flash_attn.py
+2
-1
tests/kernels/attention/test_flashinfer_trtllm_attention.py
tests/kernels/attention/test_flashinfer_trtllm_attention.py
+3
-2
tests/kernels/attention/test_flashmla.py
tests/kernels/attention/test_flashmla.py
+2
-2
tests/kernels/attention/test_flashmla_sparse.py
tests/kernels/attention/test_flashmla_sparse.py
+3
-3
tests/kernels/attention/test_lightning_attn.py
tests/kernels/attention/test_lightning_attn.py
+4
-4
tests/kernels/attention/test_merge_attn_states.py
tests/kernels/attention/test_merge_attn_states.py
+2
-2
tests/kernels/attention/test_mha_attn.py
tests/kernels/attention/test_mha_attn.py
+74
-15
tests/kernels/attention/test_pack_unpack_triton.py
tests/kernels/attention/test_pack_unpack_triton.py
+1
-1
tests/kernels/attention/test_prefix_prefill.py
tests/kernels/attention/test_prefix_prefill.py
+40
-7
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+44
-26
tests/kernels/attention/test_triton_decode_attention.py
tests/kernels/attention/test_triton_decode_attention.py
+1
-1
tests/kernels/attention/test_triton_prefill_attention.py
tests/kernels/attention/test_triton_prefill_attention.py
+225
-0
tests/kernels/attention/test_triton_unified_attention.py
tests/kernels/attention/test_triton_unified_attention.py
+3
-2
tests/kernels/attention/untest_attention_selector.py
tests/kernels/attention/untest_attention_selector.py
+29
-27
tests/kernels/attention/untest_flashinfer.py
tests/kernels/attention/untest_flashinfer.py
+5
-4
tests/kernels/core/test_activation.py
tests/kernels/core/test_activation.py
+5
-3
tests/kernels/core/test_fused_qk_norm_rope.py
tests/kernels/core/test_fused_qk_norm_rope.py
+8
-3
tests/kernels/core/test_fused_quant_layernorm.py
tests/kernels/core/test_fused_quant_layernorm.py
+1
-0
tests/kernels/core/test_layernorm.py
tests/kernels/core/test_layernorm.py
+77
-2
No files found.
Too many changes to show.
To preserve performance only
681 of 681+
files are displayed.
Plain diff
Email patch
tests/kernels/attention/test_cpu_attn.py
View file @
7e63ef82
...
...
@@ -8,6 +8,7 @@ import pytest
import
torch
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.cpu_attn
import
_get_attn_isa
if
not
current_platform
.
is_cpu
():
...
...
@@ -190,7 +191,7 @@ def varlen_with_paged_kv(
use_sink
:
bool
,
isa
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
...
tests/kernels/attention/test_flash_attn.py
View file @
7e63ef82
...
...
@@ -6,6 +6,7 @@ import pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
try
:
if
current_platform
.
is_rocm
():
...
...
@@ -132,7 +133,7 @@ def test_varlen_with_paged_kv(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
...
tests/kernels/attention/test_flashinfer_trtllm_attention.py
View file @
7e63ef82
...
...
@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (
)
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
set_random_seed
if
not
current_platform
.
is_device_capability_family
(
100
):
pytest
.
skip
(
...
...
@@ -80,7 +81,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
has_sinks
:
bool
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
=
q_quant_dtype
or
dtype
...
...
@@ -279,7 +280,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
has_sinks
:
bool
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
=
q_quant_dtype
or
dtype
...
...
tests/kernels/attention/test_flashmla.py
View file @
7e63ef82
...
...
@@ -7,12 +7,12 @@ import random
import
pytest
import
torch
from
vllm.attention.ops.flashmla
import
(
from
vllm.triton_utils
import
triton
from
vllm.v1.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
is_flashmla_dense_supported
,
)
from
vllm.triton_utils
import
triton
def
cal_diff
(
...
...
tests/kernels/attention/test_flashmla_sparse.py
View file @
7e63ef82
...
...
@@ -5,7 +5,7 @@ import torch
def
test_sparse_flashmla_metadata_smoke
():
import
vllm.attention.ops.flashmla
as
fm
import
vllm.
v1.
attention.ops.flashmla
as
fm
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
if
not
ok
:
...
...
@@ -34,7 +34,7 @@ def test_sparse_flashmla_metadata_smoke():
def
test_sparse_flashmla_decode_smoke
():
import
vllm.attention.ops.flashmla
as
fm
import
vllm.
v1.
attention.ops.flashmla
as
fm
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
if
not
ok
:
...
...
@@ -97,7 +97,7 @@ def test_sparse_flashmla_decode_smoke():
def
test_sparse_flashmla_prefill_smoke
():
import
vllm.attention.ops.flashmla
as
fm
import
vllm.
v1.
attention.ops.flashmla
as
fm
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
if
not
ok
:
...
...
tests/kernels/attention/test_lightning_attn.py
View file @
7e63ef82
...
...
@@ -5,7 +5,7 @@ import pytest
import
torch
from
vllm.model_executor.layers.lightning_attn
import
linear_decode_forward_triton
from
vllm.
platforms
import
current_platform
from
vllm.
utils.torch_utils
import
set_random_seed
NUM_HEADS
=
[
4
,
8
]
HEAD_SIZES
=
[
64
]
...
...
@@ -124,7 +124,7 @@ def test_linear_decode_forward_triton(
torch
.
set_default_device
(
"cuda"
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
base
=
0.01
q
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
1
,
head_size
,
dtype
=
dtype
)
k
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
1
,
head_size
,
dtype
=
dtype
)
...
...
@@ -167,7 +167,7 @@ def test_linear_decode_forward_triton_with_padding(
torch
.
set_default_device
(
"cuda"
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
batch_size
=
4
base
=
0.01
...
...
@@ -231,7 +231,7 @@ def test_lightning_attention_reference(
torch
.
set_default_device
(
"cuda"
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
base
=
0.01
q
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
seq_len
,
head_size
,
dtype
=
dtype
)
...
...
tests/kernels/attention/test_merge_attn_states.py
View file @
7e63ef82
...
...
@@ -5,10 +5,10 @@ import pytest
import
torch
from
vllm._custom_ops
import
merge_attn_states
as
merge_attn_states_cuda
from
vllm.attention.ops.triton_merge_attn_states
import
(
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.ops.triton_merge_attn_states
import
(
merge_attn_states
as
merge_attn_states_triton
,
)
from
vllm.platforms
import
current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
...
...
tests/kernels/attention/test_mha_attn.py
View file @
7e63ef82
...
...
@@ -3,21 +3,23 @@
"""
Test:
* Tests for M
ultiHead
Attention layer
* Tests for M
MEncoder
Attention layer
"""
import
itertools
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.selector
import
_cached_get_attn_backend
from
vllm.model_executor.layers.attention.mm_encoder_attention
import
MMEncoderAttention
from
vllm.platforms
import
current_platform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.selector
import
_cached_get_attn_backend
@
pytest
.
fixture
(
autouse
=
True
)
...
...
@@ -34,7 +36,7 @@ if current_platform.is_rocm():
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_mha_attn_platform
(
device
:
str
):
def
test_mha_attn_platform
(
default_vllm_config
,
device
:
str
):
"""
Test the attention selector between different platform and device.
"""
...
...
@@ -42,35 +44,31 @@ def test_mha_attn_platform(device: str):
if
device
==
"cpu"
:
with
(
patch
(
"vllm.attention.layer.current_platform"
,
CpuPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CpuPlatform
()),
):
attn
=
M
ultiHead
Attention
(
16
,
64
,
scale
=
1
)
attn
=
M
MEncoder
Attention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
elif
device
==
"hip"
:
with
(
patch
(
"vllm.attention.layer.current_platform"
,
RocmPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
RocmPlatform
()),
):
attn
=
M
ultiHead
Attention
(
16
,
64
,
scale
=
1
)
attn
=
M
MEncoder
Attention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
else
:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
with
(
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
):
attn
=
M
ultiHead
Attention
(
16
,
64
,
scale
=
1
)
attn
=
M
MEncoder
Attention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention
with
(
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
):
attn
=
M
ultiHead
Attention
(
16
,
72
,
scale
=
1
)
attn
=
M
MEncoder
Attention
(
16
,
72
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
...
...
@@ -94,6 +92,10 @@ def ref_attention(
BATCH_SIZES
=
[
1
,
16
]
SEQ_LENS
=
[
1
]
VAR_SEQ_LENS
=
[
[
2
,
2
],
[
2
,
3
,
4
],
]
NUM_HEADS
=
[
1
,
16
]
NUM_KV_HEADS
=
[
1
]
HEAD_SIZES
=
[
64
,
80
]
...
...
@@ -114,6 +116,7 @@ CUDA_DEVICES = ["cuda"]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_mha_attn_forward
(
default_vllm_config
,
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
...
...
@@ -122,7 +125,7 @@ def test_mha_attn_forward(
dtype
:
torch
.
dtype
,
device
:
str
,
):
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_dtype
(
dtype
)
...
...
@@ -130,7 +133,7 @@ def test_mha_attn_forward(
k
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
)
v
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
)
scale
=
1.0
/
head_size
**
0.5
attn
=
M
ultiHead
Attention
(
attn
=
M
MEncoder
Attention
(
num_heads
,
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
)
output
=
attn
(
q
,
k
,
v
)
...
...
@@ -151,3 +154,59 @@ def test_mha_attn_forward(
scale
=
scale
,
).
reshape
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
torch
.
testing
.
assert_close
(
output
,
ref_output
)
@
pytest
.
mark
.
parametrize
(
"var_seq_len"
,
VAR_SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_kv_heads"
,
NUM_KV_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_mha_attn_varlen_forward
(
default_vllm_config
,
var_seq_len
:
list
[
int
],
num_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
):
set_random_seed
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_dtype
(
dtype
)
q
=
torch
.
randn
(
1
,
sum
(
var_seq_len
),
num_heads
,
head_size
)
k
=
torch
.
randn
(
1
,
sum
(
var_seq_len
),
num_kv_heads
,
head_size
)
v
=
torch
.
randn
(
1
,
sum
(
var_seq_len
),
num_kv_heads
,
head_size
)
cu_seqlens
=
torch
.
tensor
(
[
0
]
+
list
(
itertools
.
accumulate
(
var_seq_len
)),
dtype
=
torch
.
int32
)
scale
=
1.0
/
head_size
**
0.5
attn
=
MMEncoderAttention
(
num_heads
,
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
)
output
=
attn
(
q
,
k
,
v
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
torch
.
tensor
(
max
(
var_seq_len
))
)
assert
num_heads
%
num_kv_heads
==
0
num_queries_per_kv
=
num_heads
//
num_kv_heads
if
num_queries_per_kv
>
1
:
k
=
torch
.
repeat_interleave
(
k
,
num_queries_per_kv
,
dim
=
2
)
v
=
torch
.
repeat_interleave
(
v
,
num_queries_per_kv
,
dim
=
2
)
ref_output
=
[]
for
q_i
,
k_i
,
v_i
in
zip
(
torch
.
split
(
q
,
var_seq_len
,
dim
=
1
),
torch
.
split
(
k
,
var_seq_len
,
dim
=
1
),
torch
.
split
(
v
,
var_seq_len
,
dim
=
1
),
):
output_i
=
ref_attention
(
q_i
,
k_i
,
v_i
,
scale
=
scale
,
)
ref_output
.
append
(
output_i
)
ref_output
=
torch
.
cat
(
ref_output
,
dim
=
1
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
)
tests/kernels/attention/test_pack_unpack_triton.py
View file @
7e63ef82
...
...
@@ -4,7 +4,7 @@
import
torch
from
torch.testing
import
assert_close
from
vllm.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
from
vllm.
v1.
attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
def
test_pack_seq_basic_fp8
():
...
...
tests/kernels/attention/test_prefix_prefill.py
View file @
7e63ef82
...
...
@@ -10,10 +10,12 @@ import pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.attention.ops.chunked_prefill_paged_decode
import
chunked_prefill_paged_decode
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
set_random_seed
from
vllm.v1.attention.ops.chunked_prefill_paged_decode
import
(
chunked_prefill_paged_decode
,
)
from
vllm.v1.attention.ops.prefix_prefill
import
context_attention_fwd
if
not
current_platform
.
is_rocm
():
from
xformers
import
ops
as
xops
...
...
@@ -117,6 +119,7 @@ def test_contexted_kv_attention(
kv_cache_dtype
:
str
,
device
:
str
,
op
:
Callable
,
block_size
:
int
=
32
,
)
->
None
:
if
"fp8"
in
kv_cache_dtype
and
not
current_platform
.
has_device_capability
(
89
):
pytest
.
skip
(
...
...
@@ -130,7 +133,7 @@ def test_contexted_kv_attention(
):
pytest
.
skip
(
"ROCm custom paged attention does not support fp8_e5m2 KV cache"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
...
...
@@ -143,7 +146,6 @@ def test_contexted_kv_attention(
MAX_CTX_LEN
=
1024
BS
=
10
cache_size
=
640
block_size
=
32
max_block_per_request
=
64
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
# ensure one sequence in batch is a decode
...
...
@@ -338,6 +340,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype
:
str
,
device
:
str
,
op
:
Callable
,
block_size
:
int
=
32
,
)
->
None
:
if
"fp8"
in
kv_cache_dtype
and
not
current_platform
.
has_device_capability
(
89
):
pytest
.
skip
(
...
...
@@ -351,7 +354,7 @@ def test_contexted_kv_attention_alibi(
):
pytest
.
skip
(
"ROCm custom paged attention does not support fp8_e5m2 KV cache"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
...
...
@@ -390,7 +393,6 @@ def test_contexted_kv_attention_alibi(
MAX_CTX_LEN
=
1024
BS
=
10
cache_size
=
640
block_size
=
32
max_block_per_request
=
64
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
ctx_lens
=
[
random
.
randint
(
16
,
MAX_CTX_LEN
)
for
_
in
range
(
BS
)]
...
...
@@ -643,3 +645,34 @@ def test_contexted_kv_attention_alibi_f32(
test_contexted_kv_attention_alibi
(
num_heads
,
num_queries_per_kv
,
head_size
,
dtype
,
kv_cache_dtype
,
device
,
op
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"op"
,
OPS
)
@
torch
.
inference_mode
()
def
test_qwen3_nonstandard_block_size
(
head_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
op
:
Callable
,
)
->
None
:
"""
A separate test function specifically added
for Qwen3-Next-80B (Block Size 544).
"""
if
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"544 block size optimization is only for ROCm."
)
test_contexted_kv_attention
(
num_heads
=
64
,
num_queries_per_kv
=
1
,
head_size
=
head_size
,
block_size
=
544
,
sliding_window
=
0
,
dtype
=
dtype
,
kv_cache_dtype
=
"auto"
,
device
=
device
,
op
=
op
,
)
tests/kernels/attention/test_rocm_attention_selector.py
View file @
7e63ef82
...
...
@@ -4,8 +4,10 @@
import
pytest
import
torch
from
vllm.
attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.
config
import
AttentionConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
@
pytest
.
fixture
(
autouse
=
True
)
...
...
@@ -16,40 +18,56 @@ def clear_cache():
@
pytest
.
mark
.
skip
(
reason
=
"Skipped for now. Should be revisited."
)
def
test_selector
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"ROCM_ATTN"
)
# Set the current platform to ROCm using monkeypatch
m
onkeypatch
.
setattr
(
"vllm.v1.attention.selector.current_platform"
,
RocmPlatform
()
)
# Set the current platform to ROCm using monkeypatch
monkeypatch
.
setattr
(
"vllm.attention.selector.current_platform"
,
RocmPlatform
())
# Test standard ROCm attention
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
ROCM_ATTN
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
# Test standard ROCm attention
with
set_current_vllm_config
(
vllm_config
):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"ROCM_FLASH"
or
backend
.
get_name
()
==
"TRITON_ATTN"
# MLA test for deepseek related
# MLA test for deepseek related
# Change the attention backend to triton MLA
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
TRITON_MLA
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
# change the attention backend to triton MLA
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"TRITON_MLA"
)
with
set_current_vllm_config
(
vllm_config
):
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
# If attention backend is None
# If use_mla is true
# The selected backend is triton MLA
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
""
)
# If attention backend is None
# If use_mla is true
# The selected backend is triton MLA
attention_config
=
AttentionConfig
(
backend
=
None
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
with
set_current_vllm_config
(
vllm_config
):
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
# change the attention backend to AITER MLA
# m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
# assert backend.get_name() == "ROCM_AITER_MLA"
# # If attention backend is None
# # If use_mla is true
# # If VLLM_ROCM_USE_AITER is enabled
# # The selected backend is ROCM_AITER_MLA
# m.setenv("VLLM_ATTENTION_BACKEND", "")
# m.setenv("VLLM_ROCM_USE_AITER", "1")
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
# assert backend.get_name() == "ROCM_AITER_MLA"
# Change the attention backend to AITER MLA
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
ROCM_AITER_MLA
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
# with set_current_vllm_config(vllm_config):
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
# assert backend.get_name() == "ROCM_AITER_MLA"
# # If attention backend is None
# # If use_mla is true
# # If VLLM_ROCM_USE_AITER is enabled
# # The selected backend is ROCM_AITER_MLA
# with monkeypatch.context() as m:
# m.setenv("VLLM_ROCM_USE_AITER", "1")
# attention_config = AttentionConfig(backend=None)
# vllm_config = VllmConfig(attention_config=attention_config)
# with set_current_vllm_config(vllm_config):
# backend = get_attn_backend(
# 576, torch.bfloat16, "auto", 1, False, use_mla=True
# )
# assert backend.get_name() == "ROCM_AITER_MLA"
tests/kernels/attention/test_triton_decode_attention.py
View file @
7e63ef82
...
...
@@ -4,8 +4,8 @@
import
pytest
import
torch
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.ops.triton_decode_attention
import
decode_attention_fwd
@
pytest
.
mark
.
parametrize
(
"B"
,
[
3
,
5
])
...
...
tests/kernels/attention/test_triton_prefill_attention.py
0 → 100644
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.v1.attention.ops.triton_prefill_attention
import
context_attention_fwd
def
ref_masked_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
is_causal
:
bool
=
True
,
sliding_window_q
:
int
|
None
=
None
,
sliding_window_k
:
int
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Reference implementation using PyTorch SDPA."""
# q, k, v: [total_tokens, num_heads, head_dim]
# SDPA expects [batch, num_heads, seq_len, head_dim]
total_tokens
=
q
.
shape
[
0
]
# Add batch dimension and transpose
q
=
q
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
k
=
k
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
v
=
v
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
# Create attention mask if needed
attn_mask
=
None
use_causal
=
is_causal
# If we have sliding window or need custom masking, create explicit mask
sliding_window_q
=
sliding_window_q
if
sliding_window_q
is
not
None
else
0
sliding_window_k
=
sliding_window_k
if
sliding_window_k
is
not
None
else
0
if
(
sliding_window_q
>
0
)
or
(
sliding_window_k
>
0
):
# Position indices
pos_q
=
torch
.
arange
(
total_tokens
,
device
=
q
.
device
).
unsqueeze
(
1
)
pos_k
=
torch
.
arange
(
total_tokens
,
device
=
q
.
device
).
unsqueeze
(
0
)
# Start with valid mask (False = no masking)
mask
=
torch
.
ones
(
(
total_tokens
,
total_tokens
),
dtype
=
torch
.
bool
,
device
=
q
.
device
)
# Apply causal mask
if
is_causal
:
mask
=
mask
&
(
pos_q
>=
pos_k
)
# Apply sliding window masks
sliding_window_mask
=
torch
.
ones_like
(
mask
)
if
sliding_window_q
>
0
:
sliding_window_mask
&=
pos_q
-
pos_k
<=
sliding_window_q
if
sliding_window_k
>
0
:
sliding_window_mask
&=
pos_k
-
pos_q
<=
sliding_window_k
mask
=
mask
&
sliding_window_mask
attn_mask
=
torch
.
where
(
mask
,
0.0
,
float
(
"-inf"
)).
to
(
q
.
dtype
)
use_causal
=
False
# Don't use is_causal when providing explicit mask
# Use SDPA
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attn_mask
,
is_causal
=
use_causal
,
dropout_p
=
0.0
)
# Convert back to original shape: [total_tokens, num_heads, head_dim]
output
=
output
.
transpose
(
1
,
2
).
squeeze
(
0
)
return
output
@
pytest
.
mark
.
parametrize
(
"B"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_seq_len"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"H_Q"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"H_KV"
,
[
32
,
8
])
@
pytest
.
mark
.
parametrize
(
"D"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"is_causal"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
bfloat16
])
def
test_context_attention
(
B
:
int
,
max_seq_len
:
int
,
H_Q
:
int
,
H_KV
:
int
,
D
:
int
,
is_causal
:
bool
,
dtype
:
torch
.
dtype
,
):
"""Test basic context attention without sliding window."""
torch
.
manual_seed
(
42
)
# Generate random sequence lengths for each batch
seq_lens
=
torch
.
randint
(
max_seq_len
//
2
,
max_seq_len
+
1
,
(
B
,),
device
=
"cuda"
)
total_tokens
=
seq_lens
.
sum
().
item
()
# Create batch start locations
b_start_loc
=
torch
.
zeros
(
B
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
b_start_loc
[
1
:]
=
torch
.
cumsum
(
seq_lens
[:
-
1
],
dim
=
0
)
# Create input tensors
q
=
torch
.
randn
(
total_tokens
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
k
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
v
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
o
=
torch
.
zeros_like
(
q
)
# Call Triton kernel
context_attention_fwd
(
q
,
k
,
v
,
o
,
b_start_loc
,
seq_lens
,
max_seq_len
,
is_causal
=
is_causal
,
sliding_window_q
=
None
,
sliding_window_k
=
None
,
)
# Compute reference output for each sequence in batch
o_ref
=
torch
.
zeros_like
(
q
)
for
i
in
range
(
B
):
start
=
b_start_loc
[
i
].
item
()
end
=
start
+
seq_lens
[
i
].
item
()
q_seq
=
q
[
start
:
end
]
k_seq
=
k
[
start
:
end
]
v_seq
=
v
[
start
:
end
]
# Expand KV heads if using GQA
if
H_Q
!=
H_KV
:
kv_group_num
=
H_Q
//
H_KV
k_seq
=
k_seq
.
repeat_interleave
(
kv_group_num
,
dim
=
1
)
v_seq
=
v_seq
.
repeat_interleave
(
kv_group_num
,
dim
=
1
)
o_ref
[
start
:
end
]
=
ref_masked_attention
(
q_seq
,
k_seq
,
v_seq
,
is_causal
=
is_causal
,
sliding_window_q
=
None
,
sliding_window_k
=
None
,
)
# Compare outputs
torch
.
testing
.
assert_close
(
o
,
o_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
@
pytest
.
mark
.
parametrize
(
"B"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"max_seq_len"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"H_Q"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"H_KV"
,
[
32
,
8
])
@
pytest
.
mark
.
parametrize
(
"D"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[(
32
,
32
),
(
32
,
0
),
(
0
,
32
)])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
bfloat16
])
def
test_context_attention_sliding_window
(
B
:
int
,
max_seq_len
:
int
,
H_Q
:
int
,
H_KV
:
int
,
D
:
int
,
sliding_window
:
tuple
[
int
,
int
],
dtype
:
torch
.
dtype
,
):
sliding_window_q
,
sliding_window_k
=
sliding_window
"""Test context attention with sliding window."""
torch
.
manual_seed
(
42
)
# Generate random sequence lengths for each batch
seq_lens
=
torch
.
randint
(
max_seq_len
//
2
,
max_seq_len
+
1
,
(
B
,),
device
=
"cuda"
)
total_tokens
=
seq_lens
.
sum
().
item
()
# Create batch start locations
b_start_loc
=
torch
.
zeros
(
B
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
b_start_loc
[
1
:]
=
torch
.
cumsum
(
seq_lens
[:
-
1
],
dim
=
0
)
# Create input tensors
q
=
torch
.
randn
(
total_tokens
,
H_Q
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
k
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
v
=
torch
.
randn
(
total_tokens
,
H_KV
,
D
,
dtype
=
dtype
,
device
=
"cuda"
)
o
=
torch
.
zeros_like
(
q
)
# Call Triton kernel
context_attention_fwd
(
q
,
k
,
v
,
o
,
b_start_loc
,
seq_lens
,
max_seq_len
,
is_causal
=
False
,
sliding_window_q
=
sliding_window_q
,
sliding_window_k
=
sliding_window_k
,
)
# Compute reference output for each sequence in batch
o_ref
=
torch
.
zeros_like
(
q
)
for
i
in
range
(
B
):
start
=
b_start_loc
[
i
].
item
()
end
=
start
+
seq_lens
[
i
].
item
()
q_seq
=
q
[
start
:
end
]
k_seq
=
k
[
start
:
end
]
v_seq
=
v
[
start
:
end
]
# Expand KV heads if using GQA
if
H_Q
!=
H_KV
:
kv_group_num
=
H_Q
//
H_KV
k_seq
=
k_seq
.
repeat_interleave
(
kv_group_num
,
dim
=
1
)
v_seq
=
v_seq
.
repeat_interleave
(
kv_group_num
,
dim
=
1
)
o_ref
[
start
:
end
]
=
ref_masked_attention
(
q_seq
,
k_seq
,
v_seq
,
is_causal
=
False
,
sliding_window_q
=
sliding_window_q
if
sliding_window_q
>
0
else
None
,
sliding_window_k
=
sliding_window_k
if
sliding_window_k
>
0
else
None
,
)
# Compare outputs
torch
.
testing
.
assert_close
(
o
,
o_ref
,
rtol
=
2e-2
,
atol
=
2e-2
)
tests/kernels/attention/test_triton_unified_attention.py
View file @
7e63ef82
...
...
@@ -5,9 +5,10 @@
import
pytest
import
torch
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
next_power_of_2
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.ops.triton_unified_attention
import
unified_attention
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
...
...
@@ -113,7 +114,7 @@ def test_triton_unified_attn(
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
...
tests/kernels/attention/untest_attention_selector.py
View file @
7e63ef82
...
...
@@ -6,11 +6,13 @@ from unittest.mock import patch
import
pytest
import
torch
from
vllm.
attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.
config
import
AttentionConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.platforms
import
current_platform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
@
pytest
.
fixture
(
autouse
=
True
)
...
...
@@ -73,18 +75,18 @@ def generate_params():
@
pytest
.
mark
.
parametrize
(
"device, name, use_mla, block_size"
,
generate_params
())
def
test_
env
(
def
test_
backend_selection
(
device
:
str
,
name
:
str
,
use_mla
:
bool
,
block_size
:
int
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
"""Test attention backend selection with valid device-backend pairs."""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
name
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
if
use_mla
else
"0"
)
# Create AttentionConfig with the specified backend
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
[
name
]
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
with
set_current_vllm_config
(
vllm_config
):
if
device
==
"cpu"
:
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
)
...
...
@@ -180,7 +182,7 @@ def test_env(
expected
=
name
assert
backend
.
get_name
()
==
expected
elif
name
==
"FLASH_ATTN_MLA"
:
from
vllm.attention.
util
s.fa_utils
import
(
from
vllm.
v1.
attention.
backend
s.fa_utils
import
(
flash_attn_supports_mla
,
)
...
...
@@ -217,27 +219,32 @@ def test_env(
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"cuda"
])
def
test_fp32_fallback
(
device
:
str
):
"""Test attention backend selection with fp32."""
if
device
==
"cpu"
:
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
assert
backend
.
get_name
()
==
"CPU_ATTN"
# Use default config (no backend specified)
vllm_config
=
VllmConfig
()
elif
device
==
"cuda"
:
with
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
assert
backend
.
get_name
()
==
"FLEX_ATTENTION"
with
set_current_vllm_config
(
vllm_config
):
if
device
==
"cpu"
:
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
assert
backend
.
get_name
()
==
"CPU_ATTN"
elif
device
==
"cuda"
:
with
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
assert
backend
.
get_name
()
==
"FLEX_ATTENTION"
def
test_flash_attn
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test FlashAttn validation."""
pytest
.
skip
(
"Skipping as current backend selector does not "
"handle fallbacks when a backend is
set via env var
."
"handle fallbacks when a backend is
explicitly set
."
)
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASH_ATTN"
)
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
FLASH_ATTN
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
with
set_current_vllm_config
(
vllm_config
):
# Unsupported CUDA arch
monkeypatch
.
setattr
(
torch
.
cuda
,
"get_device_capability"
,
lambda
_
=
None
:
(
7
,
5
))
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
)
...
...
@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
def
test_invalid_
env
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_invalid_
backend
(
):
"""Test that invalid attention backend names raise ValueError."""
with
(
monkeypatch
.
context
()
as
m
,
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()),
pytest
.
raises
(
ValueError
),
):
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"INVALID"
)
# Should raise ValueError for invalid backend
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
get_attn_backend
(
32
,
torch
.
float16
,
None
,
16
)
assert
"Invalid value 'INVALID'"
in
str
(
exc_info
.
value
)
# Invalid backend name should raise ValueError when creating enum
AttentionConfig
(
backend
=
AttentionBackendEnum
[
"INVALID"
])
tests/kernels/attention/untest_flashinfer.py
View file @
7e63ef82
...
...
@@ -5,6 +5,7 @@
import
pytest
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
try
:
import
flashinfer
...
...
@@ -101,7 +102,7 @@ def test_flashinfer_decode_with_paged_kv(
sliding_window
:
int
|
None
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
...
...
@@ -196,7 +197,7 @@ def test_flashinfer_prefill_with_paged_kv(
sliding_window
:
int
|
None
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
...
@@ -299,7 +300,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
)
->
None
:
pytest
.
skip
(
"TODO: fix the accuracy issue"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
...
@@ -409,7 +410,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
)
->
None
:
# test doesn't work for num_heads = (16,16)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
...
...
tests/kernels/core/test_activation.py
View file @
7e63ef82
...
...
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.activation import (
SiluAndMul
,
SwigluOAIAndMul
,
)
from
vllm.
platforms
import
current_platform
from
vllm.
utils.torch_utils
import
set_random_seed
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
2048
]
# Arbitrary values for testing
...
...
@@ -45,6 +45,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_act_and_mul
(
default_vllm_config
,
activation
:
str
,
num_tokens
:
int
,
d
:
int
,
...
...
@@ -52,7 +53,7 @@ def test_act_and_mul(
seed
:
int
,
device
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
x
=
torch
.
randn
(
num_tokens
,
2
*
d
,
dtype
=
dtype
)
if
activation
==
"silu_and_mul"
:
...
...
@@ -122,6 +123,7 @@ def test_act_and_mul(
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_activation
(
default_vllm_config
,
activation
:
type
[
torch
.
nn
.
Module
],
num_tokens
:
int
,
d
:
int
,
...
...
@@ -129,7 +131,7 @@ def test_activation(
seed
:
int
,
device
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
x
=
torch
.
randn
(
num_tokens
,
d
,
dtype
=
dtype
)
layer
=
activation
[
0
]()
...
...
tests/kernels/core/test_fused_qk_norm_rope.py
View file @
7e63ef82
...
...
@@ -8,11 +8,13 @@ from tests.kernels.utils import opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
IS_NEOX
=
[
True
,
False
]
EPS_VALUES
=
[
1e-5
,
1e-6
]
SEEDS
=
[
13
]
PARTIAL_ROPE
=
[
True
,
False
]
CUDA_DEVICES
=
[
"cuda:0"
]
...
...
@@ -52,16 +54,19 @@ def _apply_qk_norm_rope(
@
pytest
.
mark
.
parametrize
(
"is_neox"
,
IS_NEOX
)
@
pytest
.
mark
.
parametrize
(
"eps"
,
EPS_VALUES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"rotary_ratio"
,
[
1.0
,
0.5
,
0.25
])
@
torch
.
inference_mode
()
def
test_fused_qk_norm_rope_matches_reference
(
default_vllm_config
,
device
:
str
,
dtype
:
torch
.
dtype
,
is_neox
:
bool
,
eps
:
float
,
seed
:
int
,
rotary_ratio
:
float
,
):
torch
.
set_default_device
(
device
)
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
num_heads
,
num_kv_heads
,
head_dim
=
16
,
4
,
128
num_tokens
=
4
...
...
@@ -76,10 +81,10 @@ def test_fused_qk_norm_rope_matches_reference(
k_norm
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
q_weight
=
q_norm
.
weight
.
data
k_weight
=
k_norm
.
weight
.
data
rotary_dim
=
int
(
head_dim
*
rotary_ratio
)
rope
=
RotaryEmbedding
(
head_size
=
head_dim
,
rotary_dim
=
head
_dim
,
rotary_dim
=
rotary
_dim
,
max_position_embeddings
=
4096
,
base
=
10000.0
,
is_neox_style
=
is_neox
,
...
...
tests/kernels/core/test_fused_quant_layernorm.py
View file @
7e63ef82
...
...
@@ -147,6 +147,7 @@ def ops_impl(
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_rms_norm
(
default_vllm_config
,
num_tokens
:
int
,
hidden_size
:
int
,
add_residual
:
bool
,
...
...
tests/kernels/core/test_layernorm.py
View file @
7e63ef82
...
...
@@ -7,7 +7,7 @@ import torch
from
tests.kernels.quant_utils
import
FP8_DTYPE
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.
platforms
import
current_platform
from
vllm.
utils.torch_utils
import
set_random_seed
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
...
...
@@ -26,6 +26,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
@
pytest
.
mark
.
parametrize
(
"strided_input"
,
[
False
,
True
])
@
torch
.
inference_mode
()
def
test_rms_norm
(
default_vllm_config
,
num_tokens
:
int
,
hidden_size
:
int
,
add_residual
:
bool
,
...
...
@@ -34,7 +35,7 @@ def test_rms_norm(
device
:
str
,
strided_input
:
bool
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
layer
=
RMSNorm
(
hidden_size
).
to
(
dtype
=
dtype
)
layer
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
...
...
@@ -70,6 +71,80 @@ def test_rms_norm(
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"add_residual"
,
ADD_RESIDUAL
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"quant_scale"
,
[
0.01
,
1.0
,
10.0
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"strided_input"
,
[
False
,
True
])
def
test_fused_rms_norm_quant
(
num_tokens
:
int
,
hidden_size
:
int
,
add_residual
:
bool
,
dtype
:
torch
.
dtype
,
quant_scale
:
float
,
seed
:
int
,
device
:
str
,
strided_input
:
bool
,
)
->
None
:
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
weight
=
torch
.
empty
(
hidden_size
,
dtype
=
dtype
).
normal_
(
mean
=
1.0
,
std
=
0.1
)
scale
=
1
/
(
2
*
hidden_size
)
last_dim
=
2
*
hidden_size
if
strided_input
else
hidden_size
x_base
=
torch
.
randn
(
num_tokens
,
last_dim
,
dtype
=
dtype
)
x
=
x_base
[...,
:
hidden_size
]
assert
x
.
is_contiguous
()
!=
strided_input
x
*=
scale
if
add_residual
:
residual
=
torch
.
randn_like
(
x
)
*
scale
residual_fused
=
residual
.
clone
()
else
:
residual
=
residual_fused
=
None
out_norm
=
torch
.
empty_like
(
x
)
out_quant
=
torch
.
empty_like
(
x
,
dtype
=
FP8_DTYPE
)
out_quant_fused
=
torch
.
empty_like
(
out_quant
)
quant_scale_t
=
torch
.
tensor
(
quant_scale
,
dtype
=
torch
.
float32
)
if
add_residual
:
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
(
out_quant_fused
,
x
,
residual_fused
,
weight
,
quant_scale_t
,
1e-6
)
# Unfused kernel is in-place so it goes second
# Also use a separate clone of x to avoid modifying the input
x_unfused_base
=
x_base
.
clone
()
x_unfused
=
x_unfused_base
[...,
:
hidden_size
]
assert
x_unfused
.
is_contiguous
()
!=
strided_input
torch
.
ops
.
_C
.
fused_add_rms_norm
(
x_unfused
,
residual
,
weight
,
1e-6
)
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
out_quant
,
x_unfused
.
contiguous
(),
quant_scale_t
)
torch
.
cuda
.
synchronize
()
torch
.
testing
.
assert_close
(
residual_fused
,
residual
,
atol
=
1e-2
,
rtol
=
1e-2
)
opcheck
(
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
,
(
out_quant_fused
,
x
,
residual_fused
,
weight
,
quant_scale_t
,
1e-6
),
)
else
:
torch
.
ops
.
_C
.
rms_norm_static_fp8_quant
(
out_quant_fused
,
x
,
weight
,
quant_scale_t
,
1e-6
)
torch
.
ops
.
_C
.
rms_norm
(
out_norm
,
x
,
weight
,
1e-6
)
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
out_quant
,
out_norm
,
quant_scale_t
)
opcheck
(
torch
.
ops
.
_C
.
rms_norm_static_fp8_quant
,
(
out_quant_fused
,
x
,
weight
,
quant_scale_t
,
1e-6
),
)
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
...
...
Prev
1
…
18
19
20
21
22
23
24
25
26
…
35
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment