Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7e63ef82
Commit
7e63ef82
authored
Jan 21, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0' into v0.14.0-dev
parents
8cbcac5d
b17039bc
Changes
681
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
531 additions
and
106 deletions
+531
-106
tests/kernels/attention/test_cpu_attn.py
tests/kernels/attention/test_cpu_attn.py
+2
-1
tests/kernels/attention/test_flash_attn.py
tests/kernels/attention/test_flash_attn.py
+2
-1
tests/kernels/attention/test_flashinfer_trtllm_attention.py
tests/kernels/attention/test_flashinfer_trtllm_attention.py
+3
-2
tests/kernels/attention/test_flashmla.py
tests/kernels/attention/test_flashmla.py
+2
-2
tests/kernels/attention/test_flashmla_sparse.py
tests/kernels/attention/test_flashmla_sparse.py
+3
-3
tests/kernels/attention/test_lightning_attn.py
tests/kernels/attention/test_lightning_attn.py
+4
-4
tests/kernels/attention/test_merge_attn_states.py
tests/kernels/attention/test_merge_attn_states.py
+2
-2
tests/kernels/attention/test_mha_attn.py
tests/kernels/attention/test_mha_attn.py
+74
-15
tests/kernels/attention/test_pack_unpack_triton.py
tests/kernels/attention/test_pack_unpack_triton.py
+1
-1
tests/kernels/attention/test_prefix_prefill.py
tests/kernels/attention/test_prefix_prefill.py
+40
-7
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+44
-26
tests/kernels/attention/test_triton_decode_attention.py
tests/kernels/attention/test_triton_decode_attention.py
+1
-1
tests/kernels/attention/test_triton_prefill_attention.py
tests/kernels/attention/test_triton_prefill_attention.py
+225
-0
tests/kernels/attention/test_triton_unified_attention.py
tests/kernels/attention/test_triton_unified_attention.py
+3
-2
tests/kernels/attention/untest_attention_selector.py
tests/kernels/attention/untest_attention_selector.py
+29
-27
tests/kernels/attention/untest_flashinfer.py
tests/kernels/attention/untest_flashinfer.py
+5
-4
tests/kernels/core/test_activation.py
tests/kernels/core/test_activation.py
+5
-3
tests/kernels/core/test_fused_qk_norm_rope.py
tests/kernels/core/test_fused_qk_norm_rope.py
+8
-3
tests/kernels/core/test_fused_quant_layernorm.py
tests/kernels/core/test_fused_quant_layernorm.py
+1
-0
tests/kernels/core/test_layernorm.py
tests/kernels/core/test_layernorm.py
+77
-2
No files found.
Too many changes to show.
To preserve performance only
681 of 681+
files are displayed.
Plain diff
Email patch
tests/kernels/attention/test_cpu_attn.py
View file @
7e63ef82
...
@@ -8,6 +8,7 @@ import pytest
...
@@ -8,6 +8,7 @@ import pytest
import
torch
import
torch
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.cpu_attn
import
_get_attn_isa
from
vllm.v1.attention.backends.cpu_attn
import
_get_attn_isa
if
not
current_platform
.
is_cpu
():
if
not
current_platform
.
is_cpu
():
...
@@ -190,7 +191,7 @@ def varlen_with_paged_kv(
...
@@ -190,7 +191,7 @@ def varlen_with_paged_kv(
use_sink
:
bool
,
use_sink
:
bool
,
isa
:
str
,
isa
:
str
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
...
tests/kernels/attention/test_flash_attn.py
View file @
7e63ef82
...
@@ -6,6 +6,7 @@ import pytest
...
@@ -6,6 +6,7 @@ import pytest
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
try
:
try
:
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
...
@@ -132,7 +133,7 @@ def test_varlen_with_paged_kv(
...
@@ -132,7 +133,7 @@ def test_varlen_with_paged_kv(
"Flash attention with quantized inputs is only "
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
"supported on version 3 with bfloat16 base type"
)
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
...
tests/kernels/attention/test_flashinfer_trtllm_attention.py
View file @
7e63ef82
...
@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (
...
@@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
set_random_seed
if
not
current_platform
.
is_device_capability_family
(
100
):
if
not
current_platform
.
is_device_capability_family
(
100
):
pytest
.
skip
(
pytest
.
skip
(
...
@@ -80,7 +81,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
...
@@ -80,7 +81,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
has_sinks
:
bool
,
has_sinks
:
bool
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
=
q_quant_dtype
or
dtype
q_quant_dtype
=
q_quant_dtype
or
dtype
...
@@ -279,7 +280,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
...
@@ -279,7 +280,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
has_sinks
:
bool
,
has_sinks
:
bool
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
=
q_quant_dtype
or
dtype
q_quant_dtype
=
q_quant_dtype
or
dtype
...
...
tests/kernels/attention/test_flashmla.py
View file @
7e63ef82
...
@@ -7,12 +7,12 @@ import random
...
@@ -7,12 +7,12 @@ import random
import
pytest
import
pytest
import
torch
import
torch
from
vllm.attention.ops.flashmla
import
(
from
vllm.triton_utils
import
triton
from
vllm.v1.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
flash_mla_with_kvcache
,
get_mla_metadata
,
get_mla_metadata
,
is_flashmla_dense_supported
,
is_flashmla_dense_supported
,
)
)
from
vllm.triton_utils
import
triton
def
cal_diff
(
def
cal_diff
(
...
...
tests/kernels/attention/test_flashmla_sparse.py
View file @
7e63ef82
...
@@ -5,7 +5,7 @@ import torch
...
@@ -5,7 +5,7 @@ import torch
def
test_sparse_flashmla_metadata_smoke
():
def
test_sparse_flashmla_metadata_smoke
():
import
vllm.attention.ops.flashmla
as
fm
import
vllm.
v1.
attention.ops.flashmla
as
fm
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
if
not
ok
:
if
not
ok
:
...
@@ -34,7 +34,7 @@ def test_sparse_flashmla_metadata_smoke():
...
@@ -34,7 +34,7 @@ def test_sparse_flashmla_metadata_smoke():
def
test_sparse_flashmla_decode_smoke
():
def
test_sparse_flashmla_decode_smoke
():
import
vllm.attention.ops.flashmla
as
fm
import
vllm.
v1.
attention.ops.flashmla
as
fm
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
if
not
ok
:
if
not
ok
:
...
@@ -97,7 +97,7 @@ def test_sparse_flashmla_decode_smoke():
...
@@ -97,7 +97,7 @@ def test_sparse_flashmla_decode_smoke():
def
test_sparse_flashmla_prefill_smoke
():
def
test_sparse_flashmla_prefill_smoke
():
import
vllm.attention.ops.flashmla
as
fm
import
vllm.
v1.
attention.ops.flashmla
as
fm
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
ok
,
reason
=
fm
.
is_flashmla_sparse_supported
()
if
not
ok
:
if
not
ok
:
...
...
tests/kernels/attention/test_lightning_attn.py
View file @
7e63ef82
...
@@ -5,7 +5,7 @@ import pytest
...
@@ -5,7 +5,7 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.layers.lightning_attn
import
linear_decode_forward_triton
from
vllm.model_executor.layers.lightning_attn
import
linear_decode_forward_triton
from
vllm.
platforms
import
current_platform
from
vllm.
utils.torch_utils
import
set_random_seed
NUM_HEADS
=
[
4
,
8
]
NUM_HEADS
=
[
4
,
8
]
HEAD_SIZES
=
[
64
]
HEAD_SIZES
=
[
64
]
...
@@ -124,7 +124,7 @@ def test_linear_decode_forward_triton(
...
@@ -124,7 +124,7 @@ def test_linear_decode_forward_triton(
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
base
=
0.01
base
=
0.01
q
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
1
,
head_size
,
dtype
=
dtype
)
q
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
1
,
head_size
,
dtype
=
dtype
)
k
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
1
,
head_size
,
dtype
=
dtype
)
k
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
1
,
head_size
,
dtype
=
dtype
)
...
@@ -167,7 +167,7 @@ def test_linear_decode_forward_triton_with_padding(
...
@@ -167,7 +167,7 @@ def test_linear_decode_forward_triton_with_padding(
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
batch_size
=
4
batch_size
=
4
base
=
0.01
base
=
0.01
...
@@ -231,7 +231,7 @@ def test_lightning_attention_reference(
...
@@ -231,7 +231,7 @@ def test_lightning_attention_reference(
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
torch
.
cuda
.
manual_seed_all
(
42
)
current_platform
.
seed_everything
(
42
)
set_random_seed
(
42
)
base
=
0.01
base
=
0.01
q
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
seq_len
,
head_size
,
dtype
=
dtype
)
q
=
base
*
torch
.
randn
(
batch_size
,
num_heads
,
seq_len
,
head_size
,
dtype
=
dtype
)
...
...
tests/kernels/attention/test_merge_attn_states.py
View file @
7e63ef82
...
@@ -5,10 +5,10 @@ import pytest
...
@@ -5,10 +5,10 @@ import pytest
import
torch
import
torch
from
vllm._custom_ops
import
merge_attn_states
as
merge_attn_states_cuda
from
vllm._custom_ops
import
merge_attn_states
as
merge_attn_states_cuda
from
vllm.attention.ops.triton_merge_attn_states
import
(
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.ops.triton_merge_attn_states
import
(
merge_attn_states
as
merge_attn_states_triton
,
merge_attn_states
as
merge_attn_states_triton
,
)
)
from
vllm.platforms
import
current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
...
...
tests/kernels/attention/test_mha_attn.py
View file @
7e63ef82
...
@@ -3,21 +3,23 @@
...
@@ -3,21 +3,23 @@
"""
"""
Test:
Test:
* Tests for M
ultiHead
Attention layer
* Tests for M
MEncoder
Attention layer
"""
"""
import
itertools
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
import
torch
import
torch
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.model_executor.layers.attention.mm_encoder_attention
import
MMEncoderAttention
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.selector
import
_cached_get_attn_backend
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.selector
import
_cached_get_attn_backend
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
...
@@ -34,7 +36,7 @@ if current_platform.is_rocm():
...
@@ -34,7 +36,7 @@ if current_platform.is_rocm():
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_mha_attn_platform
(
device
:
str
):
def
test_mha_attn_platform
(
default_vllm_config
,
device
:
str
):
"""
"""
Test the attention selector between different platform and device.
Test the attention selector between different platform and device.
"""
"""
...
@@ -42,35 +44,31 @@ def test_mha_attn_platform(device: str):
...
@@ -42,35 +44,31 @@ def test_mha_attn_platform(device: str):
if
device
==
"cpu"
:
if
device
==
"cpu"
:
with
(
with
(
patch
(
"vllm.attention.layer.current_platform"
,
CpuPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CpuPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CpuPlatform
()),
):
):
attn
=
M
ultiHead
Attention
(
16
,
64
,
scale
=
1
)
attn
=
M
MEncoder
Attention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
elif
device
==
"hip"
:
elif
device
==
"hip"
:
with
(
with
(
patch
(
"vllm.attention.layer.current_platform"
,
RocmPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
RocmPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
RocmPlatform
()),
):
):
attn
=
M
ultiHead
Attention
(
16
,
64
,
scale
=
1
)
attn
=
M
MEncoder
Attention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
else
:
else
:
# Test CUDA with head_size=64 (divisible by 32)
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
# - should use vLLM's FlashAttention
with
(
with
(
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
):
):
attn
=
M
ultiHead
Attention
(
16
,
64
,
scale
=
1
)
attn
=
M
MEncoder
Attention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
# Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention
# - should use vLLM's FlashAttention
with
(
with
(
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
):
):
attn
=
M
ultiHead
Attention
(
16
,
72
,
scale
=
1
)
attn
=
M
MEncoder
Attention
(
16
,
72
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
...
@@ -94,6 +92,10 @@ def ref_attention(
...
@@ -94,6 +92,10 @@ def ref_attention(
BATCH_SIZES
=
[
1
,
16
]
BATCH_SIZES
=
[
1
,
16
]
SEQ_LENS
=
[
1
]
SEQ_LENS
=
[
1
]
VAR_SEQ_LENS
=
[
[
2
,
2
],
[
2
,
3
,
4
],
]
NUM_HEADS
=
[
1
,
16
]
NUM_HEADS
=
[
1
,
16
]
NUM_KV_HEADS
=
[
1
]
NUM_KV_HEADS
=
[
1
]
HEAD_SIZES
=
[
64
,
80
]
HEAD_SIZES
=
[
64
,
80
]
...
@@ -114,6 +116,7 @@ CUDA_DEVICES = ["cuda"]
...
@@ -114,6 +116,7 @@ CUDA_DEVICES = ["cuda"]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_mha_attn_forward
(
def
test_mha_attn_forward
(
default_vllm_config
,
batch_size
:
int
,
batch_size
:
int
,
seq_len
:
int
,
seq_len
:
int
,
num_heads
:
int
,
num_heads
:
int
,
...
@@ -122,7 +125,7 @@ def test_mha_attn_forward(
...
@@ -122,7 +125,7 @@ def test_mha_attn_forward(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
,
device
:
str
,
):
):
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
...
@@ -130,7 +133,7 @@ def test_mha_attn_forward(
...
@@ -130,7 +133,7 @@ def test_mha_attn_forward(
k
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
)
k
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
)
v
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
)
v
=
torch
.
randn
(
batch_size
,
seq_len
,
num_kv_heads
*
head_size
)
scale
=
1.0
/
head_size
**
0.5
scale
=
1.0
/
head_size
**
0.5
attn
=
M
ultiHead
Attention
(
attn
=
M
MEncoder
Attention
(
num_heads
,
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
num_heads
,
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
)
)
output
=
attn
(
q
,
k
,
v
)
output
=
attn
(
q
,
k
,
v
)
...
@@ -151,3 +154,59 @@ def test_mha_attn_forward(
...
@@ -151,3 +154,59 @@ def test_mha_attn_forward(
scale
=
scale
,
scale
=
scale
,
).
reshape
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
).
reshape
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
torch
.
testing
.
assert_close
(
output
,
ref_output
)
torch
.
testing
.
assert_close
(
output
,
ref_output
)
@
pytest
.
mark
.
parametrize
(
"var_seq_len"
,
VAR_SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_kv_heads"
,
NUM_KV_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_mha_attn_varlen_forward
(
default_vllm_config
,
var_seq_len
:
list
[
int
],
num_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
):
set_random_seed
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_dtype
(
dtype
)
q
=
torch
.
randn
(
1
,
sum
(
var_seq_len
),
num_heads
,
head_size
)
k
=
torch
.
randn
(
1
,
sum
(
var_seq_len
),
num_kv_heads
,
head_size
)
v
=
torch
.
randn
(
1
,
sum
(
var_seq_len
),
num_kv_heads
,
head_size
)
cu_seqlens
=
torch
.
tensor
(
[
0
]
+
list
(
itertools
.
accumulate
(
var_seq_len
)),
dtype
=
torch
.
int32
)
scale
=
1.0
/
head_size
**
0.5
attn
=
MMEncoderAttention
(
num_heads
,
head_size
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
)
output
=
attn
(
q
,
k
,
v
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
torch
.
tensor
(
max
(
var_seq_len
))
)
assert
num_heads
%
num_kv_heads
==
0
num_queries_per_kv
=
num_heads
//
num_kv_heads
if
num_queries_per_kv
>
1
:
k
=
torch
.
repeat_interleave
(
k
,
num_queries_per_kv
,
dim
=
2
)
v
=
torch
.
repeat_interleave
(
v
,
num_queries_per_kv
,
dim
=
2
)
ref_output
=
[]
for
q_i
,
k_i
,
v_i
in
zip
(
torch
.
split
(
q
,
var_seq_len
,
dim
=
1
),
torch
.
split
(
k
,
var_seq_len
,
dim
=
1
),
torch
.
split
(
v
,
var_seq_len
,
dim
=
1
),
):
output_i
=
ref_attention
(
q_i
,
k_i
,
v_i
,
scale
=
scale
,
)
ref_output
.
append
(
output_i
)
ref_output
=
torch
.
cat
(
ref_output
,
dim
=
1
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
)
tests/kernels/attention/test_pack_unpack_triton.py
View file @
7e63ef82
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
torch
import
torch
from
torch.testing
import
assert_close
from
torch.testing
import
assert_close
from
vllm.attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
from
vllm.
v1.
attention.ops.common
import
pack_seq_triton
,
unpack_seq_triton
def
test_pack_seq_basic_fp8
():
def
test_pack_seq_basic_fp8
():
...
...
tests/kernels/attention/test_prefix_prefill.py
View file @
7e63ef82
...
@@ -10,10 +10,12 @@ import pytest
...
@@ -10,10 +10,12 @@ import pytest
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm.attention.ops.chunked_prefill_paged_decode
import
chunked_prefill_paged_decode
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
set_random_seed
from
vllm.v1.attention.ops.chunked_prefill_paged_decode
import
(
chunked_prefill_paged_decode
,
)
from
vllm.v1.attention.ops.prefix_prefill
import
context_attention_fwd
if
not
current_platform
.
is_rocm
():
if
not
current_platform
.
is_rocm
():
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
@@ -117,6 +119,7 @@ def test_contexted_kv_attention(
...
@@ -117,6 +119,7 @@ def test_contexted_kv_attention(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
op
:
Callable
,
op
:
Callable
,
block_size
:
int
=
32
,
)
->
None
:
)
->
None
:
if
"fp8"
in
kv_cache_dtype
and
not
current_platform
.
has_device_capability
(
89
):
if
"fp8"
in
kv_cache_dtype
and
not
current_platform
.
has_device_capability
(
89
):
pytest
.
skip
(
pytest
.
skip
(
...
@@ -130,7 +133,7 @@ def test_contexted_kv_attention(
...
@@ -130,7 +133,7 @@ def test_contexted_kv_attention(
):
):
pytest
.
skip
(
"ROCm custom paged attention does not support fp8_e5m2 KV cache"
)
pytest
.
skip
(
"ROCm custom paged attention does not support fp8_e5m2 KV cache"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# Need this, otherwise when we capture the graph the process
...
@@ -143,7 +146,6 @@ def test_contexted_kv_attention(
...
@@ -143,7 +146,6 @@ def test_contexted_kv_attention(
MAX_CTX_LEN
=
1024
MAX_CTX_LEN
=
1024
BS
=
10
BS
=
10
cache_size
=
640
cache_size
=
640
block_size
=
32
max_block_per_request
=
64
max_block_per_request
=
64
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
# ensure one sequence in batch is a decode
# ensure one sequence in batch is a decode
...
@@ -338,6 +340,7 @@ def test_contexted_kv_attention_alibi(
...
@@ -338,6 +340,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
op
:
Callable
,
op
:
Callable
,
block_size
:
int
=
32
,
)
->
None
:
)
->
None
:
if
"fp8"
in
kv_cache_dtype
and
not
current_platform
.
has_device_capability
(
89
):
if
"fp8"
in
kv_cache_dtype
and
not
current_platform
.
has_device_capability
(
89
):
pytest
.
skip
(
pytest
.
skip
(
...
@@ -351,7 +354,7 @@ def test_contexted_kv_attention_alibi(
...
@@ -351,7 +354,7 @@ def test_contexted_kv_attention_alibi(
):
):
pytest
.
skip
(
"ROCm custom paged attention does not support fp8_e5m2 KV cache"
)
pytest
.
skip
(
"ROCm custom paged attention does not support fp8_e5m2 KV cache"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
# Need this, otherwise when we capture the graph the process
# Need this, otherwise when we capture the graph the process
...
@@ -390,7 +393,6 @@ def test_contexted_kv_attention_alibi(
...
@@ -390,7 +393,6 @@ def test_contexted_kv_attention_alibi(
MAX_CTX_LEN
=
1024
MAX_CTX_LEN
=
1024
BS
=
10
BS
=
10
cache_size
=
640
cache_size
=
640
block_size
=
32
max_block_per_request
=
64
max_block_per_request
=
64
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
query_lens
=
[
random
.
randint
(
16
,
MAX_SEQ_LEN
)
for
_
in
range
(
BS
)]
ctx_lens
=
[
random
.
randint
(
16
,
MAX_CTX_LEN
)
for
_
in
range
(
BS
)]
ctx_lens
=
[
random
.
randint
(
16
,
MAX_CTX_LEN
)
for
_
in
range
(
BS
)]
...
@@ -643,3 +645,34 @@ def test_contexted_kv_attention_alibi_f32(
...
@@ -643,3 +645,34 @@ def test_contexted_kv_attention_alibi_f32(
test_contexted_kv_attention_alibi
(
test_contexted_kv_attention_alibi
(
num_heads
,
num_queries_per_kv
,
head_size
,
dtype
,
kv_cache_dtype
,
device
,
op
num_heads
,
num_queries_per_kv
,
head_size
,
dtype
,
kv_cache_dtype
,
device
,
op
)
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"op"
,
OPS
)
@
torch
.
inference_mode
()
def
test_qwen3_nonstandard_block_size
(
head_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
op
:
Callable
,
)
->
None
:
"""
A separate test function specifically added
for Qwen3-Next-80B (Block Size 544).
"""
if
not
current_platform
.
is_rocm
():
pytest
.
skip
(
"544 block size optimization is only for ROCm."
)
test_contexted_kv_attention
(
num_heads
=
64
,
num_queries_per_kv
=
1
,
head_size
=
head_size
,
block_size
=
544
,
sliding_window
=
0
,
dtype
=
dtype
,
kv_cache_dtype
=
"auto"
,
device
=
device
,
op
=
op
,
)
tests/kernels/attention/test_rocm_attention_selector.py
View file @
7e63ef82
...
@@ -4,8 +4,10 @@
...
@@ -4,8 +4,10 @@
import
pytest
import
pytest
import
torch
import
torch
from
vllm.
attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.
config
import
AttentionConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
...
@@ -16,40 +18,56 @@ def clear_cache():
...
@@ -16,40 +18,56 @@ def clear_cache():
@
pytest
.
mark
.
skip
(
reason
=
"Skipped for now. Should be revisited."
)
@
pytest
.
mark
.
skip
(
reason
=
"Skipped for now. Should be revisited."
)
def
test_selector
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_selector
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
# Set the current platform to ROCm using monkeypatch
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"ROCM_ATTN"
)
m
onkeypatch
.
setattr
(
"vllm.v1.attention.selector.current_platform"
,
RocmPlatform
()
)
# Set the current platform to ROCm using monkeypatch
# Test standard ROCm attention
monkeypatch
.
setattr
(
"vllm.attention.selector.current_platform"
,
RocmPlatform
())
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
ROCM_ATTN
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
# Test standard ROCm attention
with
set_current_vllm_config
(
vllm_config
):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"ROCM_FLASH"
or
backend
.
get_name
()
==
"TRITON_ATTN"
assert
backend
.
get_name
()
==
"ROCM_FLASH"
or
backend
.
get_name
()
==
"TRITON_ATTN"
# MLA test for deepseek related
# MLA test for deepseek related
# Change the attention backend to triton MLA
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
TRITON_MLA
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
# change the attention backend to triton MLA
with
set_current_vllm_config
(
vllm_config
):
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"TRITON_MLA"
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
assert
backend
.
get_name
()
==
"TRITON_MLA"
# If attention backend is None
# If attention backend is None
# If use_mla is true
# If use_mla is true
# The selected backend is triton MLA
# The selected backend is triton MLA
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
""
)
attention_config
=
AttentionConfig
(
backend
=
None
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
with
set_current_vllm_config
(
vllm_config
):
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
backend
=
get_attn_backend
(
576
,
torch
.
bfloat16
,
"auto"
,
16
,
False
,
use_mla
=
True
)
assert
backend
.
get_name
()
==
"TRITON_MLA"
assert
backend
.
get_name
()
==
"TRITON_MLA"
# change the attention backend to AITER MLA
# Change the attention backend to AITER MLA
# m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
ROCM_AITER_MLA
)
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
# assert backend.get_name() == "ROCM_AITER_MLA"
# with set_current_vllm_config(vllm_config):
# # If attention backend is None
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
# # If use_mla is true
# assert backend.get_name() == "ROCM_AITER_MLA"
# # If VLLM_ROCM_USE_AITER is enabled
# # The selected backend is ROCM_AITER_MLA
# # If attention backend is None
# m.setenv("VLLM_ATTENTION_BACKEND", "")
# # If use_mla is true
# m.setenv("VLLM_ROCM_USE_AITER", "1")
# # If VLLM_ROCM_USE_AITER is enabled
# backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
# # The selected backend is ROCM_AITER_MLA
# assert backend.get_name() == "ROCM_AITER_MLA"
# with monkeypatch.context() as m:
# m.setenv("VLLM_ROCM_USE_AITER", "1")
# attention_config = AttentionConfig(backend=None)
# vllm_config = VllmConfig(attention_config=attention_config)
# with set_current_vllm_config(vllm_config):
# backend = get_attn_backend(
# 576, torch.bfloat16, "auto", 1, False, use_mla=True
# )
# assert backend.get_name() == "ROCM_AITER_MLA"
tests/kernels/attention/test_triton_decode_attention.py
View file @
7e63ef82
...
@@ -4,8 +4,8 @@
...
@@ -4,8 +4,8 @@
import
pytest
import
pytest
import
torch
import
torch
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.ops.triton_decode_attention
import
decode_attention_fwd
@
pytest
.
mark
.
parametrize
(
"B"
,
[
3
,
5
])
@
pytest
.
mark
.
parametrize
(
"B"
,
[
3
,
5
])
...
...
tests/kernels/attention/test_triton_prefill_attention.py
0 → 100644
View file @
7e63ef82
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.v1.attention.ops.triton_prefill_attention
import
context_attention_fwd
def
ref_masked_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
is_causal
:
bool
=
True
,
sliding_window_q
:
int
|
None
=
None
,
sliding_window_k
:
int
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Reference implementation using PyTorch SDPA."""
# q, k, v: [total_tokens, num_heads, head_dim]
# SDPA expects [batch, num_heads, seq_len, head_dim]
total_tokens
=
q
.
shape
[
0
]
# Add batch dimension and transpose
q
=
q
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
k
=
k
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
v
=
v
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# [1, num_heads, total_tokens, head_dim]
# Create attention mask if needed
attn_mask
=
None
use_causal
=
is_causal
# If we have sliding window or need custom masking, create explicit mask
sliding_window_q
=
sliding_window_q
if
sliding_window_q
is
not
None
else
0
sliding_window_k
=
sliding_window_k
if
sliding_window_k
is
not
None
else
0
if
(
sliding_window_q
>
0
)
or
(
sliding_window_k
>
0
):
# Position indices
pos_q
=
torch
.
arange
(
total_tokens
,
device
=
q
.
device
).
unsqueeze
(
1
)
pos_k
=
torch
.
arange
(
total_tokens
,
device
=
q
.
device
).
unsqueeze
(
0
)
# Start with valid mask (False = no masking)
mask
=
torch
.
ones
(
(
total_tokens
,
total_tokens
),
dtype
=
torch
.
bool
,
device
=
q
.
device
)
# Apply causal mask
if
is_causal
:
mask
=
mask
&
(
pos_q
>=
pos_k
)
# Apply sliding window masks
sliding_window_mask
=
torch
.
ones_like
(
mask
)
if
sliding_window_q
>
0
:
sliding_window_mask
&=
pos_q
-
pos_k
<=
sliding_window_q
if
sliding_window_k
>
0
:
sliding_window_mask
&=
pos_k
-
pos_q
<=
sliding_window_k
mask
=
mask
&
sliding_window_mask
attn_mask
=
torch
.
where
(
mask
,
0.0
,
float
(
"-inf"
)).
to
(
q
.
dtype
)
use_causal
=
False
# Don't use is_causal when providing explicit mask
# Use SDPA
output
=
F
.
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_mask
=
attn_mask
,
is_causal
=
use_causal
,
dropout_p
=
0.0
)
# Convert back to original shape: [total_tokens, num_heads, head_dim]
output
=
output
.
transpose
(
1
,
2
).
squeeze
(
0
)
return
output
@pytest.mark.parametrize("B", [5])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("is_causal", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    is_causal: bool,
    dtype: torch.dtype,
):
    """Compare the Triton prefill kernel with the SDPA reference, no window."""
    torch.manual_seed(42)
    cuda = "cuda"

    # One random length in [max_seq_len // 2, max_seq_len] per sequence.
    seq_lens = torch.randint(
        max_seq_len // 2, max_seq_len + 1, (B,), device=cuda
    )
    total_tokens = seq_lens.sum().item()

    # Sequences are packed back to back, so each start offset is the running
    # sum of the lengths that precede it (the first one stays at 0).
    b_start_loc = torch.zeros(B, dtype=torch.int32, device=cuda)
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    # Packed inputs; K/V may carry fewer heads than Q (grouped-query case).
    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device=cuda)
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=cuda)
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=cuda)
    o = torch.zeros_like(q)

    # Kernel under test writes its result into `o`.
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=is_causal,
        sliding_window_q=None,
        sliding_window_k=None,
    )

    # Build the expected output one sequence at a time.
    o_ref = torch.zeros_like(q)
    expand_kv = H_Q != H_KV
    for seq_idx in range(B):
        lo = b_start_loc[seq_idx].item()
        hi = lo + seq_lens[seq_idx].item()
        span = slice(lo, hi)

        k_seq = k[span]
        v_seq = v[span]
        if expand_kv:
            # The reference needs matching head counts: repeat each KV head
            # for every query head in its group.
            repeats = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(repeats, dim=1)
            v_seq = v_seq.repeat_interleave(repeats, dim=1)

        o_ref[span] = ref_masked_attention(
            q[span],
            k_seq,
            v_seq,
            is_causal=is_causal,
            sliding_window_q=None,
            sliding_window_k=None,
        )

    torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("max_seq_len", [1024])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D", [128])
@pytest.mark.parametrize("sliding_window", [(32, 32), (32, 0), (0, 32)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_context_attention_sliding_window(
    B: int,
    max_seq_len: int,
    H_Q: int,
    H_KV: int,
    D: int,
    sliding_window: tuple[int, int],
    dtype: torch.dtype,
):
    """Test context attention with sliding window.

    ``sliding_window`` is a ``(window_q, window_k)`` pair. A 0 entry is
    passed to the kernel unchanged and mapped to ``None`` (disabled) for
    the reference. Both implementations run with ``is_causal=False``.
    """
    # The docstring must precede this unpacking; placed after it, the string
    # is a discarded expression and the test has no __doc__.
    sliding_window_q, sliding_window_k = sliding_window

    torch.manual_seed(42)

    # Generate random sequence lengths for each batch
    seq_lens = torch.randint(
        max_seq_len // 2, max_seq_len + 1, (B,), device="cuda"
    )
    total_tokens = seq_lens.sum().item()

    # Create batch start locations: sequences are packed back to back, so
    # each start is the cumulative sum of the preceding lengths.
    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    # Create input tensors
    q = torch.randn(total_tokens, H_Q, D, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
    o = torch.zeros_like(q)

    # Call Triton kernel (result is written into `o`)
    context_attention_fwd(
        q,
        k,
        v,
        o,
        b_start_loc,
        seq_lens,
        max_seq_len,
        is_causal=False,
        sliding_window_q=sliding_window_q,
        sliding_window_k=sliding_window_k,
    )

    # Compute reference output for each sequence in batch
    o_ref = torch.zeros_like(q)
    for i in range(B):
        start = b_start_loc[i].item()
        end = start + seq_lens[i].item()

        q_seq = q[start:end]
        k_seq = k[start:end]
        v_seq = v[start:end]

        # Expand KV heads if using GQA so the reference sees equal head counts
        if H_Q != H_KV:
            kv_group_num = H_Q // H_KV
            k_seq = k_seq.repeat_interleave(kv_group_num, dim=1)
            v_seq = v_seq.repeat_interleave(kv_group_num, dim=1)

        o_ref[start:end] = ref_masked_attention(
            q_seq,
            k_seq,
            v_seq,
            is_causal=False,
            sliding_window_q=sliding_window_q if sliding_window_q > 0 else None,
            sliding_window_k=sliding_window_k if sliding_window_k > 0 else None,
        )

    # Compare outputs
    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)
tests/kernels/attention/test_triton_unified_attention.py
View file @
7e63ef82
...
@@ -5,9 +5,10 @@
...
@@ -5,9 +5,10 @@
import
pytest
import
pytest
import
torch
import
torch
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
next_power_of_2
from
vllm.utils.math_utils
import
next_power_of_2
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.ops.triton_unified_attention
import
unified_attention
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
)]
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
HEAD_SIZES
=
[
128
,
256
]
...
@@ -113,7 +114,7 @@ def test_triton_unified_attn(
...
@@ -113,7 +114,7 @@ def test_triton_unified_attn(
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
...
tests/kernels/attention/untest_attention_selector.py
View file @
7e63ef82
...
@@ -6,11 +6,13 @@ from unittest.mock import patch
...
@@ -6,11 +6,13 @@ from unittest.mock import patch
import
pytest
import
pytest
import
torch
import
torch
from
vllm.
attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.
config
import
AttentionConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
...
@@ -73,18 +75,18 @@ def generate_params():
...
@@ -73,18 +75,18 @@ def generate_params():
@
pytest
.
mark
.
parametrize
(
"device, name, use_mla, block_size"
,
generate_params
())
@
pytest
.
mark
.
parametrize
(
"device, name, use_mla, block_size"
,
generate_params
())
def
test_
env
(
def
test_
backend_selection
(
device
:
str
,
device
:
str
,
name
:
str
,
name
:
str
,
use_mla
:
bool
,
use_mla
:
bool
,
block_size
:
int
,
block_size
:
int
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
):
"""Test attention backend selection with valid device-backend pairs."""
"""Test attention backend selection with valid device-backend pairs."""
with
monkeypatch
.
context
()
as
m
:
# Create AttentionConfig with the specified backend
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
name
)
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
[
name
]
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
if
use_mla
else
"0"
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
with
set_current_vllm_config
(
vllm_config
):
if
device
==
"cpu"
:
if
device
==
"cpu"
:
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
)
...
@@ -180,7 +182,7 @@ def test_env(
...
@@ -180,7 +182,7 @@ def test_env(
expected
=
name
expected
=
name
assert
backend
.
get_name
()
==
expected
assert
backend
.
get_name
()
==
expected
elif
name
==
"FLASH_ATTN_MLA"
:
elif
name
==
"FLASH_ATTN_MLA"
:
from
vllm.attention.
util
s.fa_utils
import
(
from
vllm.
v1.
attention.
backend
s.fa_utils
import
(
flash_attn_supports_mla
,
flash_attn_supports_mla
,
)
)
...
@@ -217,27 +219,32 @@ def test_env(
...
@@ -217,27 +219,32 @@ def test_env(
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"cuda"
])
def
test_fp32_fallback
(
device
:
str
):
def
test_fp32_fallback
(
device
:
str
):
"""Test attention backend selection with fp32."""
"""Test attention backend selection with fp32."""
if
device
==
"cpu"
:
# Use default config (no backend specified)
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
vllm_config
=
VllmConfig
()
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
assert
backend
.
get_name
()
==
"CPU_ATTN"
elif
device
==
"cuda"
:
with
set_current_vllm_config
(
vllm_config
):
with
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()):
if
device
==
"cpu"
:
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
assert
backend
.
get_name
()
==
"FLEX_ATTENTION"
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
assert
backend
.
get_name
()
==
"CPU_ATTN"
elif
device
==
"cuda"
:
with
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
assert
backend
.
get_name
()
==
"FLEX_ATTENTION"
def
test_flash_attn
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_flash_attn
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test FlashAttn validation."""
"""Test FlashAttn validation."""
pytest
.
skip
(
pytest
.
skip
(
"Skipping as current backend selector does not "
"Skipping as current backend selector does not "
"handle fallbacks when a backend is
set via env var
."
"handle fallbacks when a backend is
explicitly set
."
)
)
with
monkeypatch
.
context
()
as
m
:
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
FLASH_ATTN
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASH_ATTN"
)
vllm_config
=
VllmConfig
(
attention_config
=
attention_config
)
with
set_current_vllm_config
(
vllm_config
):
# Unsupported CUDA arch
# Unsupported CUDA arch
monkeypatch
.
setattr
(
torch
.
cuda
,
"get_device_capability"
,
lambda
_
=
None
:
(
7
,
5
))
monkeypatch
.
setattr
(
torch
.
cuda
,
"get_device_capability"
,
lambda
_
=
None
:
(
7
,
5
))
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
)
...
@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
...
@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
assert
backend
.
get_name
()
!=
"FLASH_ATTN"
def
test_invalid_
env
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_invalid_
backend
(
):
"""Test that invalid attention backend names raise ValueError."""
"""Test that invalid attention backend names raise ValueError."""
with
(
with
(
monkeypatch
.
context
()
as
m
,
pytest
.
raises
(
ValueError
),
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()),
):
):
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"INVALID"
)
# Invalid backend name should raise ValueError when creating enum
AttentionConfig
(
backend
=
AttentionBackendEnum
[
"INVALID"
])
# Should raise ValueError for invalid backend
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
get_attn_backend
(
32
,
torch
.
float16
,
None
,
16
)
assert
"Invalid value 'INVALID'"
in
str
(
exc_info
.
value
)
tests/kernels/attention/untest_flashinfer.py
View file @
7e63ef82
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
pytest
import
pytest
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
try
:
try
:
import
flashinfer
import
flashinfer
...
@@ -101,7 +102,7 @@ def test_flashinfer_decode_with_paged_kv(
...
@@ -101,7 +102,7 @@ def test_flashinfer_decode_with_paged_kv(
sliding_window
:
int
|
None
,
sliding_window
:
int
|
None
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
kv_lens
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
num_kv_heads
=
num_heads
[
1
]
...
@@ -196,7 +197,7 @@ def test_flashinfer_prefill_with_paged_kv(
...
@@ -196,7 +197,7 @@ def test_flashinfer_prefill_with_paged_kv(
sliding_window
:
int
|
None
,
sliding_window
:
int
|
None
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
@@ -299,7 +300,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
...
@@ -299,7 +300,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
)
->
None
:
)
->
None
:
pytest
.
skip
(
"TODO: fix the accuracy issue"
)
pytest
.
skip
(
"TODO: fix the accuracy issue"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
...
@@ -409,7 +410,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
...
@@ -409,7 +410,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
)
->
None
:
)
->
None
:
# test doesn't work for num_heads = (16,16)
# test doesn't work for num_heads = (16,16)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
set_random_seed
(
0
)
num_seqs
=
len
(
kv_lens
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
num_kv_heads
=
num_heads
[
1
]
...
...
tests/kernels/core/test_activation.py
View file @
7e63ef82
...
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.activation import (
...
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.activation import (
SiluAndMul
,
SiluAndMul
,
SwigluOAIAndMul
,
SwigluOAIAndMul
,
)
)
from
vllm.
platforms
import
current_platform
from
vllm.
utils.torch_utils
import
set_random_seed
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
2048
]
# Arbitrary values for testing
NUM_TOKENS
=
[
7
,
83
,
2048
]
# Arbitrary values for testing
...
@@ -45,6 +45,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
...
@@ -45,6 +45,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_act_and_mul
(
def
test_act_and_mul
(
default_vllm_config
,
activation
:
str
,
activation
:
str
,
num_tokens
:
int
,
num_tokens
:
int
,
d
:
int
,
d
:
int
,
...
@@ -52,7 +53,7 @@ def test_act_and_mul(
...
@@ -52,7 +53,7 @@ def test_act_and_mul(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
x
=
torch
.
randn
(
num_tokens
,
2
*
d
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
2
*
d
,
dtype
=
dtype
)
if
activation
==
"silu_and_mul"
:
if
activation
==
"silu_and_mul"
:
...
@@ -122,6 +123,7 @@ def test_act_and_mul(
...
@@ -122,6 +123,7 @@ def test_act_and_mul(
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_activation
(
def
test_activation
(
default_vllm_config
,
activation
:
type
[
torch
.
nn
.
Module
],
activation
:
type
[
torch
.
nn
.
Module
],
num_tokens
:
int
,
num_tokens
:
int
,
d
:
int
,
d
:
int
,
...
@@ -129,7 +131,7 @@ def test_activation(
...
@@ -129,7 +131,7 @@ def test_activation(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
x
=
torch
.
randn
(
num_tokens
,
d
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
d
,
dtype
=
dtype
)
layer
=
activation
[
0
]()
layer
=
activation
[
0
]()
...
...
tests/kernels/core/test_fused_qk_norm_rope.py
View file @
7e63ef82
...
@@ -8,11 +8,13 @@ from tests.kernels.utils import opcheck
...
@@ -8,11 +8,13 @@ from tests.kernels.utils import opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
IS_NEOX
=
[
True
,
False
]
IS_NEOX
=
[
True
,
False
]
EPS_VALUES
=
[
1e-5
,
1e-6
]
EPS_VALUES
=
[
1e-5
,
1e-6
]
SEEDS
=
[
13
]
SEEDS
=
[
13
]
PARTIAL_ROPE
=
[
True
,
False
]
CUDA_DEVICES
=
[
"cuda:0"
]
CUDA_DEVICES
=
[
"cuda:0"
]
...
@@ -52,16 +54,19 @@ def _apply_qk_norm_rope(
...
@@ -52,16 +54,19 @@ def _apply_qk_norm_rope(
@
pytest
.
mark
.
parametrize
(
"is_neox"
,
IS_NEOX
)
@
pytest
.
mark
.
parametrize
(
"is_neox"
,
IS_NEOX
)
@
pytest
.
mark
.
parametrize
(
"eps"
,
EPS_VALUES
)
@
pytest
.
mark
.
parametrize
(
"eps"
,
EPS_VALUES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"rotary_ratio"
,
[
1.0
,
0.5
,
0.25
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_fused_qk_norm_rope_matches_reference
(
def
test_fused_qk_norm_rope_matches_reference
(
default_vllm_config
,
device
:
str
,
device
:
str
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
is_neox
:
bool
,
is_neox
:
bool
,
eps
:
float
,
eps
:
float
,
seed
:
int
,
seed
:
int
,
rotary_ratio
:
float
,
):
):
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
num_heads
,
num_kv_heads
,
head_dim
=
16
,
4
,
128
num_heads
,
num_kv_heads
,
head_dim
=
16
,
4
,
128
num_tokens
=
4
num_tokens
=
4
...
@@ -76,10 +81,10 @@ def test_fused_qk_norm_rope_matches_reference(
...
@@ -76,10 +81,10 @@ def test_fused_qk_norm_rope_matches_reference(
k_norm
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
k_norm
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
q_weight
=
q_norm
.
weight
.
data
q_weight
=
q_norm
.
weight
.
data
k_weight
=
k_norm
.
weight
.
data
k_weight
=
k_norm
.
weight
.
data
rotary_dim
=
int
(
head_dim
*
rotary_ratio
)
rope
=
RotaryEmbedding
(
rope
=
RotaryEmbedding
(
head_size
=
head_dim
,
head_size
=
head_dim
,
rotary_dim
=
head
_dim
,
rotary_dim
=
rotary
_dim
,
max_position_embeddings
=
4096
,
max_position_embeddings
=
4096
,
base
=
10000.0
,
base
=
10000.0
,
is_neox_style
=
is_neox
,
is_neox_style
=
is_neox
,
...
...
tests/kernels/core/test_fused_quant_layernorm.py
View file @
7e63ef82
...
@@ -147,6 +147,7 @@ def ops_impl(
...
@@ -147,6 +147,7 @@ def ops_impl(
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rms_norm
(
def
test_rms_norm
(
default_vllm_config
,
num_tokens
:
int
,
num_tokens
:
int
,
hidden_size
:
int
,
hidden_size
:
int
,
add_residual
:
bool
,
add_residual
:
bool
,
...
...
tests/kernels/core/test_layernorm.py
View file @
7e63ef82
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
from
tests.kernels.quant_utils
import
FP8_DTYPE
from
tests.kernels.quant_utils
import
FP8_DTYPE
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.
platforms
import
current_platform
from
vllm.
utils.torch_utils
import
set_random_seed
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
...
@@ -26,6 +26,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
...
@@ -26,6 +26,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
@
pytest
.
mark
.
parametrize
(
"strided_input"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"strided_input"
,
[
False
,
True
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rms_norm
(
def
test_rms_norm
(
default_vllm_config
,
num_tokens
:
int
,
num_tokens
:
int
,
hidden_size
:
int
,
hidden_size
:
int
,
add_residual
:
bool
,
add_residual
:
bool
,
...
@@ -34,7 +35,7 @@ def test_rms_norm(
...
@@ -34,7 +35,7 @@ def test_rms_norm(
device
:
str
,
device
:
str
,
strided_input
:
bool
,
strided_input
:
bool
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
layer
=
RMSNorm
(
hidden_size
).
to
(
dtype
=
dtype
)
layer
=
RMSNorm
(
hidden_size
).
to
(
dtype
=
dtype
)
layer
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
layer
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
...
@@ -70,6 +71,80 @@ def test_rms_norm(
...
@@ -70,6 +71,80 @@ def test_rms_norm(
)
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_scale", [0.01, 1.0, 10.0])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
def test_fused_rms_norm_quant(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
    quant_scale: float,
    seed: int,
    device: str,
    strided_input: bool,
) -> None:
    """Run the fused RMSNorm + static FP8 quant ops next to the unfused ops.

    With ``add_residual`` the fused add+norm+quant op is run, followed by
    the unfused add+norm and quant ops on a cloned input, and the two
    residual tensors are compared. Without it, the fused norm+quant op and
    the separate norm and quant ops are run. In both branches ``opcheck``
    is called on the fused op.

    NOTE(review): in the code visible here ``out_quant`` and
    ``out_quant_fused`` are filled but never compared with each other --
    confirm whether a final ``assert_close`` on them is intended.
    """
    set_random_seed(seed)
    torch.set_default_device(device)

    weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
    # Inputs are multiplied by this factor below (presumably to keep the
    # values small -- TODO confirm the rationale).
    scale = 1 / (2 * hidden_size)
    # For the strided case, allocate a buffer twice as wide and use only its
    # first hidden_size columns.
    last_dim = 2 * hidden_size if strided_input else hidden_size
    x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
    x = x_base[..., :hidden_size]
    # Sanity check: the view is non-contiguous exactly when strided_input.
    assert x.is_contiguous() != strided_input

    x *= scale
    if add_residual:
        residual = torch.randn_like(x) * scale
        # Separate copy for the fused op; the two are compared afterwards.
        residual_fused = residual.clone()
    else:
        residual = residual_fused = None

    out_norm = torch.empty_like(x)
    out_quant = torch.empty_like(x, dtype=FP8_DTYPE)
    out_quant_fused = torch.empty_like(out_quant)

    quant_scale_t = torch.tensor(quant_scale, dtype=torch.float32)

    if add_residual:
        torch.ops._C.fused_add_rms_norm_static_fp8_quant(
            out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6
        )

        # Unfused kernel is in-place so it goes second
        # Also use a separate clone of x to avoid modifying the input
        x_unfused_base = x_base.clone()
        x_unfused = x_unfused_base[..., :hidden_size]
        assert x_unfused.is_contiguous() != strided_input
        torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
        # The quant op is given a contiguous copy of the normalized values.
        torch.ops._C.static_scaled_fp8_quant(
            out_quant, x_unfused.contiguous(), quant_scale_t
        )

        torch.cuda.synchronize()
        # Residual as updated by the fused op vs. by the unfused op.
        torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2)

        opcheck(
            torch.ops._C.fused_add_rms_norm_static_fp8_quant,
            (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6),
        )
    else:
        torch.ops._C.rms_norm_static_fp8_quant(
            out_quant_fused, x, weight, quant_scale_t, 1e-6
        )

        # Unfused path: normalize into out_norm, then quantize into out_quant.
        torch.ops._C.rms_norm(out_norm, x, weight, 1e-6)
        torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t)
        opcheck(
            torch.ops._C.rms_norm_static_fp8_quant,
            (out_quant_fused, x, weight, quant_scale_t, 1e-6),
        )
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
# @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
...
...
Prev
1
…
18
19
20
21
22
23
24
25
26
…
35
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment