Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
469e903b
Commit
469e903b
authored
Mar 28, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-dev
parents
389ebcf7
25f560a6
Changes
535
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
976 additions
and
362 deletions
+976
-362
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+91
-24
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+90
-60
tests/kernels/test_awq_marlin.py
tests/kernels/test_awq_marlin.py
+2
-7
tests/kernels/test_block_fp8.py
tests/kernels/test_block_fp8.py
+2
-2
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+6
-6
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+96
-41
tests/kernels/test_cascade_flash_attn.py
tests/kernels/test_cascade_flash_attn.py
+4
-4
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+32
-34
tests/kernels/test_cutlass_2of4_sparse.py
tests/kernels/test_cutlass_2of4_sparse.py
+2
-3
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+15
-13
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+171
-131
tests/kernels/test_flashmla.py
tests/kernels/test_flashmla.py
+132
-0
tests/kernels/test_fused_quant_layernorm.py
tests/kernels/test_fused_quant_layernorm.py
+6
-6
tests/kernels/test_gguf.py
tests/kernels/test_gguf.py
+77
-7
tests/kernels/test_machete_mm.py
tests/kernels/test_machete_mm.py
+7
-7
tests/kernels/test_mamba_mixer2.py
tests/kernels/test_mamba_mixer2.py
+1
-2
tests/kernels/test_mamba_ssm_ssd.py
tests/kernels/test_mamba_ssm_ssd.py
+3
-5
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+86
-7
tests/kernels/test_nvfp4_scaled_mm.py
tests/kernels/test_nvfp4_scaled_mm.py
+150
-0
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+3
-3
No files found.
Too many changes to show.
To preserve performance only
535 of 535+
files are displayed.
Plain diff
Email patch
tests/kernels/test_attention.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
# This will change depending on the compute capability.
# - 512 as a buffer
...
...
@@ -25,6 +27,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS
=
4321
# Arbitrary values for testing
PARTITION_SIZE
=
512
PARTITION_SIZE_ROCM
=
256
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
...
...
@@ -85,8 +88,8 @@ def ref_single_query_cached_kv_attention(
block_table
=
block_tables_lst
[
i
]
seq_len
=
int
(
seq_lens_lst
[
i
])
keys_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
keys_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
...
...
@@ -133,7 +136,7 @@ def test_paged_attention(
kv_cache_factory
,
version
:
str
,
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
use_alibi
:
bool
,
block_size
:
int
,
...
...
@@ -146,6 +149,8 @@ def test_paged_attention(
or
(
version
==
"rocm"
and
head_size
not
in
(
64
,
128
))):
pytest
.
skip
()
global
PARTITION_SIZE
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
...
...
@@ -166,7 +171,7 @@ def test_paged_attention(
# Create the block tables.
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables_lst
:
L
ist
[
L
ist
[
int
]]
=
[]
block_tables_lst
:
l
ist
[
l
ist
[
int
]]
=
[]
for
_
in
range
(
num_seqs
):
block_table
=
[
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
...
...
@@ -214,6 +219,9 @@ def test_paged_attention(
and
block_size
==
BLOCK_SIZES
[
0
]))
elif
version
in
(
"v2"
,
"rocm"
):
if
current_platform
.
is_rocm
()
and
version
==
"rocm"
:
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
assert
PARTITION_SIZE
%
block_size
==
0
num_seqs
,
num_heads
,
head_size
=
output
.
shape
...
...
@@ -334,25 +342,31 @@ def test_paged_attention(
def
ref_multi_query_kv_attention
(
cu_seq_lens
:
L
ist
[
int
],
cu_seq_lens
:
l
ist
[
int
],
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
scale
:
float
,
alibi_bias
:
Optional
[
list
[
torch
.
Tensor
]],
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
num_seqs
=
len
(
cu_seq_lens
)
-
1
ref_outputs
:
List
[
torch
.
Tensor
]
=
[]
ref_outputs
:
list
[
torch
.
Tensor
]
=
[]
if
alibi_bias
:
assert
len
(
alibi_bias
)
==
num_seqs
for
i
in
range
(
num_seqs
):
start_idx
=
cu_seq_lens
[
i
]
end_idx
=
cu_seq_lens
[
i
+
1
]
seq_len
=
end_idx
-
start_idx
# Create attention mask.
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
dtype
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
dtype
).
min
attn_mask
=
attn_mask
.
to
(
dtype
=
dtype
)
# Create attention mask. ALiBi already includes a tril causal mask.
if
alibi_bias
:
attn_mask
=
alibi_bias
[
i
]
else
:
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
dtype
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
dtype
).
min
attn_mask
=
attn_mask
.
to
(
dtype
=
dtype
)
ref_output
=
ref_masked_attention
(
query
[
start_idx
:
end_idx
],
...
...
@@ -366,7 +380,6 @@ def ref_multi_query_kv_attention(
return
torch
.
cat
(
ref_outputs
,
dim
=
0
)
# TODO(woosuk): Add tests for USE_ALIBI=True.
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_PREFILL_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
...
...
@@ -378,11 +391,12 @@ def ref_multi_query_kv_attention(
@
torch
.
inference_mode
()
def
test_multi_query_kv_attention
(
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
use_alibi
:
bool
=
False
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -408,16 +422,40 @@ def test_multi_query_kv_attention(
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_queries_per_kv
,
dim
=
1
)
value
=
torch
.
repeat_interleave
(
value
,
num_queries_per_kv
,
dim
=
1
)
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
output
=
xops
.
memory_efficient_attention_forward
(
query
.
unsqueeze
(
0
),
key
.
unsqueeze
(
0
),
value
.
unsqueeze
(
0
),
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
scale
,
)
output
=
output
.
squeeze
(
0
)
alibi_bias
=
None
if
use_alibi
:
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
)
attn_bias
=
_make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
output
=
torch
.
empty_like
(
query
)
start
=
0
# Dynamic sequence length not supported with custom attn_bias.
for
i
,
seq_len
in
enumerate
(
seq_lens
):
end
=
start
+
seq_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
scale
)
output
[
start
:
end
].
copy_
(
out
.
view_as
(
query
[
start
:
end
]))
start
+=
seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias
=
[
b
.
materialize
(
b
.
shape
,
device
=
device
).
squeeze
()
for
b
in
attn_bias
]
else
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
output
=
xops
.
memory_efficient_attention_forward
(
query
.
unsqueeze
(
0
),
key
.
unsqueeze
(
0
),
value
.
unsqueeze
(
0
),
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
scale
,
)
output
=
output
.
squeeze
(
0
)
cu_seq_lens
=
[
0
]
for
seq_len
in
seq_lens
:
...
...
@@ -428,8 +466,37 @@ def test_multi_query_kv_attention(
key
,
value
,
scale
,
alibi_bias
,
dtype
,
)
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_PREFILL_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Xformers backend is not supported on ROCm."
)
@
torch
.
inference_mode
()
def
test_multi_query_kv_attention_with_alibi
(
num_seqs
:
int
,
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
)
->
None
:
return
test_multi_query_kv_attention
(
num_seqs
,
num_heads
,
head_size
,
dtype
,
seed
,
device
,
use_alibi
=
True
,
)
tests/kernels/test_attention_selector.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
from
unittest.mock
import
Mock
,
patch
from
unittest.mock
import
patch
import
pytest
import
torch
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.openvino
import
OpenVinoPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.platforms
import
current_platform
...
...
@@ -23,86 +22,117 @@ def clear_cache():
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
]
if
not
current_platform
()
else
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
]
if
not
current_platform
.
is_rocm
()
else
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"use_v1"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
use_v1
:
bool
,
device
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
"""
override_backend_env_variable
(
monkeypatch
,
name
)
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"TORCH_SDPA"
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
RocmPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
OpenVinoPlatform
()),
patch
.
dict
(
'sys.modules'
,
{
'openvino'
:
Mock
()}):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"OPENVINO"
else
:
if
name
in
[
"XFORMERS"
,
"FLASHINFER"
]:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
if
use_v1
else
"0"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
name
)
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
Cu
da
Platform
()):
C
p
uPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
name
def
test_flash_attn
(
monkeypatch
):
assert
backend
.
get_name
()
==
"TORCH_SDPA"
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
RocmPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
EXPECTED
=
"TRITON_ATTN_VLLM_V1"
if
use_v1
else
"ROCM_FLASH"
assert
backend
.
get_name
()
==
EXPECTED
else
:
if
name
in
[
"XFORMERS"
,
"FLASHINFER"
]:
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
EXPECTED
=
"FLASH_ATTN_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
EXPECTED
def
test_flash_attn
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# get_attn_backend
override_backend_env_variable
(
monkeypatch
,
STR_FLASH_ATTN_VAL
)
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
)
# Unsupported CUDA arch
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
(
7
,
5
)):
# Unsupported CUDA arch
monkeypatch
.
setattr
(
torch
.
cuda
,
"get_device_capability"
,
lambda
:
(
7
,
5
))
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported data type
backend
=
get_attn_backend
(
16
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Reset the monkeypatch for subsequent tests
monkeypatch
.
undo
()
# Unsupported
kv cache
data type
backend
=
get_attn_backend
(
16
,
torch
.
float
16
,
"fp8"
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported data type
backend
=
get_attn_backend
(
16
,
torch
.
float
8_e4m3fn
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported
block siz
e
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported
kv cache data typ
e
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
"fp8"
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# flash-attn is not installed
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
# Unsupported block size
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# flash-attn is not installed
import
sys
original_module
=
sys
.
modules
.
get
(
'vllm_flash_attn'
)
monkeypatch
.
setitem
(
sys
.
modules
,
'vllm_flash_attn'
,
None
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported head size
backend
=
get_attn_backend
(
17
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Restore the original module if it existed
if
original_module
is
not
None
:
monkeypatch
.
setitem
(
sys
.
modules
,
'vllm_flash_attn'
,
original_module
)
else
:
monkeypatch
.
delitem
(
sys
.
modules
,
'vllm_flash_attn'
,
raising
=
False
)
# Unsupported head size
backend
=
get_attn_backend
(
17
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
@
pytest
.
mark
.
parametrize
(
"use_v1"
,
[
True
,
False
])
def
test_invalid_env
(
use_v1
:
bool
,
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
,
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
if
use_v1
else
"0"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
STR_INVALID_VAL
)
def
test_invalid_env
(
monkeypatch
):
"""Ignore the invalid env variable if it is set."""
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
# Test with head size 32
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
==
"FLASH_ATTN"
EXPECTED
=
"FLASH_ATTN_VLLM_V1"
if
use_v1
else
"FLASH_ATTN"
assert
backend
.
get_name
()
==
EXPECTED
# when block size == 16, backend will fall back to XFORMERS
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
==
"XFORMERS"
# this behavior is not yet supported on V1.
if
use_v1
:
# TODO: support fallback on V1!
# https://github.com/vllm-project/vllm/issues/14524
pass
else
:
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
==
"XFORMERS"
tests/kernels/test_awq_marlin.py
View file @
469e903b
...
...
@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
num_bits
=
num_bits
,
)
torch_output
=
torch_moe
(
a
,
w_ref1
.
transpose
(
1
,
2
),
w_ref2
.
transpose
(
1
,
2
),
score
,
topk
,
)
torch_output
=
torch_moe
(
a
,
w_ref1
.
transpose
(
1
,
2
),
w_ref2
.
transpose
(
1
,
2
),
score
,
topk
,
None
)
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
4e-2
...
...
tests/kernels/test_block_fp8.py
View file @
469e903b
...
...
@@ -30,8 +30,8 @@ M_moe = [1, 7, 83, 512, 2048]
N_moe
=
[
4608
]
# [128, 4608, 13824]
K_moe
=
[
7168
]
# [256, 7168, 13824]
BLOCK_SIZE
=
[[
128
,
128
]]
E
=
[
256
]
# [8, 24, 128, 256]
TOP_KS
=
[
1
]
# [1, 2, 6]
E
=
[
8
,
24
]
# [8, 24, 128, 256]
TOP_KS
=
[
2
]
# [1, 2, 6]
OUT_DTYPES
=
[
torch
.
bfloat16
]
# [torch.float32, torch.half, torch.bfloat16]
SEEDS
=
[
0
]
...
...
tests/kernels/test_blocksparse_attention.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
block_table
=
block_tables_lst
[
i
]
seq_len
=
int
(
seq_lens_lst
[
i
])
keys_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
keys_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
...
...
@@ -162,7 +162,7 @@ def test_paged_attention(
kv_cache_factory
,
version
:
str
,
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
use_alibi
:
bool
,
block_size
:
int
,
...
...
@@ -331,7 +331,7 @@ def test_paged_attention(
def
ref_multi_query_kv_attention
(
cu_seq_lens
:
L
ist
[
int
],
cu_seq_lens
:
l
ist
[
int
],
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
@
torch
.
inference_mode
()
def
test_varlen_blocksparse_attention_prefill
(
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
blocksparse_local_blocks
:
int
,
blocksparse_vert_stride
:
int
,
...
...
tests/kernels/test_cache.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
List
,
Tuple
import
pytest
import
torch
...
...
@@ -9,7 +8,6 @@ import torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
align_to_256bytes
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -75,7 +73,7 @@ def test_copy_blocks(
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
block_mapping
:
L
ist
[
T
uple
[
int
,
int
]]
=
[]
block_mapping
:
l
ist
[
t
uple
[
int
,
int
]]
=
[]
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
dst1
=
dst_blocks
[
2
*
i
]
...
...
@@ -160,19 +158,20 @@ def test_reshape_and_cache(
device
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
.
item
()
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
.
item
()
)
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
...
@@ -183,9 +182,9 @@ def test_reshape_and_cache(
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
,
k_scale
.
item
()
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
,
v_scale
.
item
()
)
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
...
...
@@ -269,15 +268,16 @@ def test_reshape_and_cache_flash(
del
key_caches
del
value_caches
k_scale
=
(
key
.
amax
()
/
25
6.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
25
6.0
).
to
(
torch
.
float32
)
k_scale
=
(
key
.
amax
()
/
6
4
.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
6
4
.0
).
to
(
torch
.
float32
)
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
,
kv_cache_dtype
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
.
item
(),
kv_cache_dtype
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
,
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
.
item
()
,
kv_cache_dtype
)
else
:
cloned_key_cache
=
key_cache
.
clone
()
...
...
@@ -341,7 +341,7 @@ def test_reshape_and_cache_flash(
@
torch
.
inference_mode
()
def
test_swap_blocks
(
kv_cache_factory
,
direction
:
T
uple
[
str
,
str
],
direction
:
t
uple
[
str
,
str
],
num_mappings
:
int
,
num_heads
:
int
,
head_size
:
int
,
...
...
@@ -452,22 +452,13 @@ def _create_mla_cache(
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
device
:
str
,
align_cache
:
bool
,
)
->
torch
.
Tensor
:
cache_dtype
=
torch
.
uint8
if
kv_cache_dtype
==
"fp8"
else
dtype
if
align_cache
:
alloc_entry_size
=
align_to_256bytes
(
entry_size
,
cache_dtype
)
alloc_shape
=
(
num_blocks
,
block_size
,
alloc_entry_size
)
cache_full
=
torch
.
zeros
(
alloc_shape
,
dtype
=
cache_dtype
,
device
=
device
)
cache
=
cache_full
[...,
:
entry_size
]
else
:
cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
entry_size
,
dtype
=
cache_dtype
,
device
=
device
)
return
cache
return
torch
.
zeros
(
num_blocks
,
block_size
,
entry_size
,
dtype
=
cache_dtype
,
device
=
device
)
def
_fill_mla_cache
(
cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
):
...
...
@@ -490,7 +481,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
])
@
torch
.
inference_mode
()
def
test_concat_and_cache_mla
(
kv_lora_rank
:
int
,
...
...
@@ -502,7 +492,6 @@ def test_concat_and_cache_mla(
seed
:
int
,
device
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -522,7 +511,7 @@ def test_concat_and_cache_mla(
scale
=
torch
.
tensor
(
0.1
,
dtype
=
torch
.
float32
,
device
=
device
)
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
ref_temp
=
torch
.
zeros
(
*
kv_cache
.
shape
,
dtype
=
dtype
,
device
=
device
)
for
i
in
range
(
num_tokens
):
...
...
@@ -578,7 +567,6 @@ def test_concat_and_cache_mla(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
,
True
])
@
torch
.
inference_mode
()
def
test_copy_blocks_mla
(
kv_lora_rank
:
int
,
...
...
@@ -590,7 +578,6 @@ def test_copy_blocks_mla(
seed
:
int
,
device
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -600,7 +587,7 @@ def test_copy_blocks_mla(
kv_caches
=
[]
for
_
in
range
(
num_layers
):
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
kv_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
kv_caches
.
append
(
kv_cache
)
...
...
@@ -644,7 +631,6 @@ def test_copy_blocks_mla(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
,
True
])
@
torch
.
inference_mode
()
def
test_swap_blocks_mla
(
kv_lora_rank
:
int
,
...
...
@@ -655,7 +641,6 @@ def test_swap_blocks_mla(
seed
:
int
,
device
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -663,9 +648,9 @@ def test_swap_blocks_mla(
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
dst_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
)
_fill_mla_cache
(
dst_cache
,
kv_cache_dtype
)
...
...
@@ -685,8 +670,6 @@ def test_swap_blocks_mla(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_cache
,
dst_cache
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
cond
=
(
kv_lora_rank
==
KV_LORA_RANKS
[
0
]
and
qk_rope_head_dim
==
QK_ROPE_HEAD_DIMS
[
0
]),
)
ops
.
swap_blocks
(
src_cache
,
dst_cache
,
block_mapping_tensor
)
...
...
@@ -697,3 +680,75 @@ def test_swap_blocks_mla(
dst_cache
[
dst
].
cpu
(),
msg
=
f
"Block
{
src
}
from src should have been swapped to block "
f
"
{
dst
}
in dst_cache."
)
@
pytest
.
mark
.
parametrize
(
"kv_lora_rank"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"qk_rope_head_dim"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"max_seq_len"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
])
# You can also test "fp8" if needed.
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
device
):
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
seq_len_tensor
=
torch
.
randint
(
0
,
max_seq_len
+
1
,
(
batch_size
,
),
device
=
device
)
total_tokens
=
seq_len_tensor
.
sum
()
cu_seq_lens
=
torch
.
empty
((
batch_size
+
1
),
dtype
=
torch
.
int32
,
device
=
device
)
cu_seq_lens
[
0
]
=
0
cu_seq_lens
[
1
:]
=
seq_len_tensor
.
cumsum
(
dim
=
0
).
to
(
dtype
=
torch
.
int32
)
print
(
"seq_len_tensor"
,
seq_len_tensor
)
tot_blocks_tensor
=
(
seq_len_tensor
+
block_size
-
1
)
//
block_size
block_table
=
torch
.
empty
((
batch_size
,
num_blocks
),
dtype
=
torch
.
int32
,
device
=
device
)
for
b
in
range
(
batch_size
):
perm
=
torch
.
randperm
(
num_blocks
,
device
=
device
)
block_table
[
b
,
:]
=
perm
dst
=
torch
.
zeros
((
total_tokens
,
entry_size
),
dtype
=
src_cache
.
dtype
,
device
=
device
)
expected_batches
=
[]
for
b
in
range
(
batch_size
):
s
=
seq_len_tensor
[
b
]
if
s
==
0
:
continue
tot
=
tot_blocks_tensor
[
b
]
blocks
=
block_table
[
b
,
:
tot
].
tolist
()
gathered_rows
=
[]
for
i
in
range
(
tot
-
1
):
gathered_rows
.
append
(
src_cache
[
blocks
[
i
]])
remaining
=
s
-
(
tot
-
1
)
*
block_size
gathered_rows
.
append
(
src_cache
[
blocks
[
-
1
],
:
remaining
,
:])
batch_expected
=
torch
.
cat
(
gathered_rows
,
dim
=
0
)
expected_batches
.
append
(
batch_expected
)
expected
=
torch
.
cat
(
expected_batches
,
dim
=
0
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
gather_cache
,
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
None
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
)
ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
)
torch
.
testing
.
assert_close
(
dst
,
expected
)
tests/kernels/test_cascade_flash_attn.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
@
torch
.
inference_mode
()
def
test_merge_kernel
(
num_tokens
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
):
...
...
@@ -85,8 +85,8 @@ CASES = [
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
def
test_cascade
(
seq_lens_and_common_prefix
:
T
uple
[
L
ist
[
T
uple
[
int
,
int
]],
int
],
num_heads
:
T
uple
[
int
,
int
],
seq_lens_and_common_prefix
:
t
uple
[
l
ist
[
t
uple
[
int
,
int
]],
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
...
...
tests/kernels/test_cutlass.py
View file @
469e903b
...
...
@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from
typing
import
Type
,
Optional
import
pytest
import
torch
...
...
@@ -82,7 +81,7 @@ def cutlass_fp8_gemm_helper(m: int,
a_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
out_dtype
:
T
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
t
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
...
...
@@ -120,7 +119,7 @@ def cutlass_int8_gemm_helper(m: int,
a_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
out_dtype
:
T
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
t
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
...
...
@@ -198,7 +197,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
T
ype
[
torch
.
dtype
],
out_dtype
:
t
ype
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
512
,
512
,
...
...
@@ -208,26 +207,25 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
use_bias
,
out_dtype
=
out_dtype
)
# @pytest.mark.parametrize("a_scale_group_shape",
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("b_scale_group_shape",
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
# reason="FP8 is not supported on this GPU type.")
# def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
# b_scale_group_shape,
# out_dtype: Type[torch.dtype],
# use_bias: bool):
# cutlass_fp8_gemm_helper(512,
# 512,
# 512,
# a_scale_group_shape,
# b_scale_group_shape,
# use_bias,
# out_dtype=out_dtype)
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
type
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
,
out_dtype
=
out_dtype
)
# @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
...
...
@@ -238,7 +236,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
# reason="FP8 blockwise is not supported on this GPU type.")
# def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
# b_scale_group_shape,
# out_dtype:
T
ype[torch.dtype],
# out_dtype:
t
ype[torch.dtype],
# use_bias: bool):
# cutlass_fp8_gemm_helper(512,
# 512,
...
...
@@ -271,15 +269,15 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias
:
bool
,
device
:
str
):
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
,
out_dtype
=
torch
.
bfloat16
,
device
=
device
)
#
use_bias: bool, device: str):
#
cutlass_int8_gemm_helper(512,
#
512,
#
512,
#
a_scale_group_shape,
#
b_scale_group_shape,
#
use_bias,
#
out_dtype=torch.bfloat16,
#
device=device)
...
...
tests/kernels/test_cutlass_2of4_sparse.py
View file @
469e903b
...
...
@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_semi_structured.py`.
"""
from
typing
import
Tuple
,
Type
import
pytest
import
torch
...
...
@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
def
make_rand_sparse_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
a
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
)
b
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
).
t
()
...
...
@@ -167,7 +166,7 @@ MNK_FACTORS = [
@
pytest
.
mark
.
parametrize
(
"m, n, k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_sparse_gemm
(
m
:
int
,
k
:
int
,
n
:
int
,
dtype
:
T
ype
[
torch
.
dtype
],
def
test_cutlass_sparse_gemm
(
m
:
int
,
k
:
int
,
n
:
int
,
dtype
:
t
ype
[
torch
.
dtype
],
use_bias
:
bool
):
# Create tensors
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
469e903b
...
...
@@ -22,6 +22,16 @@ from vllm.config import VllmConfig, set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.platforms
import
current_platform
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
Encoder-decoder is only supported on V0, so set
VLLM_USE_V1=0 for all tests in the module.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
HEAD_SIZES
=
[
64
,
256
]
...
...
@@ -243,7 +253,7 @@ def _decoder_attn_setup(
test_pt
:
TestPoint
,
test_rsrcs
:
TestResources
,
block_base_addr
:
int
=
0
,
)
->
T
uple
[
QKVInputs
,
PhaseTestParameters
,
PhaseTestParameters
,
int
]:
)
->
t
uple
[
QKVInputs
,
PhaseTestParameters
,
PhaseTestParameters
,
int
]:
'''
Set up test vectors & data structures for self-attention test.
...
...
@@ -421,7 +431,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
test_pt
:
TestPoint
,
test_rsrcs
:
TestResources
,
block_base_addr
:
int
=
0
,
)
->
T
uple
[
PhaseTestParameters
,
PhaseTestParameters
]:
)
->
t
uple
[
PhaseTestParameters
,
PhaseTestParameters
]:
'''
Set up test vectors & data structures for cross-attention test.
...
...
@@ -644,11 +654,7 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
)
def
_run_decoder_self_attention_test
(
...
...
@@ -682,7 +688,6 @@ def _run_decoder_self_attention_test(
& attn_metadata
'''
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
with
set_forward_context
(
attn_metadata
,
vllm_config
):
...
...
@@ -695,8 +700,7 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
kv_cache
,
attn_metadata
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
)
def
_run_encoder_decoder_cross_attention_test
(
...
...
@@ -744,7 +748,6 @@ def _run_encoder_decoder_cross_attention_test(
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
if
cross_test_params
is
None
:
key
=
None
value
=
None
...
...
@@ -762,8 +765,7 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
attn_metadata
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
)
@
pytest
.
fixture
(
autouse
=
True
)
...
...
tests/kernels/test_flash_attn.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -8,8 +8,8 @@ import torch
from
vllm.platforms
import
current_platform
if
current_platform
():
import
flash_attn
if
current_platform
.
is_rocm
():
from
flash_attn
import
flash_attn
_varlen_func
else
:
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
...
...
@@ -20,6 +20,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES
=
[
128
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
QDTYPES
=
[
None
,
torch
.
float8_e4m3fn
]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS
=
[
32768
,
2048
]
...
...
@@ -29,8 +30,8 @@ def ref_paged_attn(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
query_lens
:
L
ist
[
int
],
kv_lens
:
L
ist
[
int
],
query_lens
:
l
ist
[
int
],
kv_lens
:
l
ist
[
int
],
block_tables
:
torch
.
Tensor
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
...
...
@@ -40,7 +41,7 @@ def ref_paged_attn(
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
outputs
:
L
ist
[
torch
.
Tensor
]
=
[]
outputs
:
l
ist
[
torch
.
Tensor
]
=
[]
start_idx
=
0
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
...
...
@@ -79,91 +80,124 @@ def ref_paged_attn(
return
torch
.
cat
(
outputs
,
dim
=
0
)
if
not
current_platform
():
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
use_out
:
bool
,
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
sliding_window
:
Optional
[
int
],
fa_version
:
int
,
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
q
=
query
.
unsqueeze
(
1
)
out
=
torch
.
empty_like
(
q
)
if
use_out
else
None
output
=
flash_attn_with_kvcache
(
q
=
q
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
out
=
out
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
fa_version
=
fa_version
,
)
output
=
output
if
not
use_out
else
out
output
=
output
.
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"flash_attn_with_paged_kv is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
QDTYPES
)
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
use_out
:
bool
,
kv_lens
:
list
[
int
],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
sliding_window
:
Optional
[
int
],
fa_version
:
int
,
q_dtype
:
Optional
[
torch
.
dtype
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
if
q_dtype
is
not
None
and
(
dtype
!=
torch
.
bfloat16
or
fa_version
==
2
):
pytest
.
skip
(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
q
=
query
.
unsqueeze
(
1
)
out
=
torch
.
empty_like
(
q
)
if
use_out
else
None
maybe_quantized_query
=
q
maybe_quantized_key_cache
=
key_cache
maybe_quantized_value_cache
=
value_cache
q_descale
=
None
k_descale
=
None
v_descale
=
None
if
q_dtype
is
not
None
:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query
=
query
.
to
(
q_dtype
)
maybe_quantized_key_cache
=
key_cache
.
to
(
q_dtype
)
maybe_quantized_value_cache
=
value_cache
.
to
(
q_dtype
)
scale_shape
=
(
num_seqs
,
num_kv_heads
)
q_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
k_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
v_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
output
=
flash_attn_with_kvcache
(
q
=
maybe_quantized_query
,
k_cache
=
maybe_quantized_key_cache
,
v_cache
=
maybe_quantized_value_cache
,
out
=
out
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
fa_version
=
fa_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
output
=
output
if
not
use_out
else
out
output
=
output
.
squeeze
(
1
)
atol
,
rtol
=
1.5e-2
,
1e-2
if
q_dtype
is
not
None
:
atol
,
rtol
=
1.5e-1
,
1.5e-1
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"varlen_with_paged_kv is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
...
...
@@ -176,11 +210,12 @@ if not current_platform():
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
QDTYPES
)
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
use_out
:
bool
,
seq_lens
:
L
ist
[
T
uple
[
int
,
int
]],
num_heads
:
T
uple
[
int
,
int
],
seq_lens
:
l
ist
[
t
uple
[
int
,
int
]],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
...
...
@@ -188,11 +223,15 @@ def test_varlen_with_paged_kv(
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
fa_version
:
int
,
q_dtype
:
Optional
[
torch
.
dtype
],
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
if
q_dtype
is
not
None
and
(
dtype
!=
torch
.
bfloat16
or
fa_version
==
2
):
pytest
.
skip
(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
...
...
@@ -219,9 +258,6 @@ def test_varlen_with_paged_kv(
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
cu_kv_lens
=
torch
.
tensor
([
0
]
+
kv_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
kv_lens
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
...
...
@@ -231,42 +267,43 @@ def test_varlen_with_paged_kv(
dtype
=
torch
.
int32
)
out
=
torch
.
empty_like
(
query
)
if
use_out
else
None
if
current_platform
():
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
# fa_version=fa_version,
)
else
:
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
fa_version
=
fa_version
,
)
maybe_quantized_query
=
query
maybe_quantized_key_cache
=
key_cache
maybe_quantized_value_cache
=
value_cache
q_descale
=
None
k_descale
=
None
v_descale
=
None
if
q_dtype
is
not
None
:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query
=
query
.
to
(
q_dtype
)
maybe_quantized_key_cache
=
key_cache
.
to
(
q_dtype
)
maybe_quantized_value_cache
=
value_cache
.
to
(
q_dtype
)
scale_shape
=
(
num_seqs
,
num_kv_heads
)
q_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
k_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
v_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
output
=
flash_attn_varlen_func
(
q
=
maybe_quantized_query
,
k
=
maybe_quantized_key_cache
,
v
=
maybe_quantized_value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
fa_version
=
fa_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
output
=
output
if
not
use_out
else
out
ref_output
=
ref_paged_attn
(
...
...
@@ -280,5 +317,8 @@ def test_varlen_with_paged_kv(
sliding_window
=
sliding_window
,
soft_cap
=
soft_cap
,
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
atol
,
rtol
=
1.5e-2
,
1e-2
if
q_dtype
is
not
None
:
atol
,
rtol
=
1.5e-1
,
1.5e-1
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
\ No newline at end of file
tests/kernels/test_flashmla.py
0 → 100644
View file @
469e903b
# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py
# SPDX-License-Identifier: Apache-2.0
import
math
import
random
import
pytest
import
torch
import
triton
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
is_flashmla_supported
)
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
x
,
y
=
x
.
double
(),
y
.
double
()
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
(
(
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
assert
cos_diff
<
1e-5
FLASH_MLA_UNSUPPORTED_REASON
=
is_flashmla_supported
()[
1
]
\
if
not
is_flashmla_supported
()[
0
]
else
"FlashMLA is supported"
@
pytest
.
mark
.
skipif
(
not
is_flashmla_supported
()[
0
],
reason
=
FLASH_MLA_UNSUPPORTED_REASON
)
@
pytest
.
mark
.
parametrize
(
"b"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"s_q"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"h_q"
,
[
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"h_kv"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
576
])
@
pytest
.
mark
.
parametrize
(
"dv"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
torch
.
inference_mode
()
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
block_size
,
causal
,
varlen
):
# TODO: parametrize using pytest
dtype
=
torch
.
bfloat16
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
, "
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
"
)
cache_seqlens
=
torch
.
full
((
b
,
),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
for
i
in
range
(
b
):
cache_seqlens
[
i
]
=
max
(
random
.
normalvariate
(
mean_sk
,
mean_sk
/
2
),
s_q
)
total_seqlens
=
cache_seqlens
.
sum
().
item
()
max_seqlen
=
cache_seqlens
.
max
().
item
()
max_seqlen_pad
=
triton
.
cdiv
(
max_seqlen
,
256
)
*
256
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
for
i
in
range
(
b
):
blocked_k
.
view
(
b
,
max_seqlen_pad
,
h_kv
,
d
)[
i
,
cache_seqlens
[
i
].
item
():]
=
float
(
"nan"
)
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
def
flash_mla
():
return
flash_mla_with_kvcache
(
q
,
blocked_k
,
block_table
,
cache_seqlens
,
dv
,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
)
def
scaled_dot_product_attention
(
query
,
key
,
value
,
is_causal
=
False
):
query
=
query
.
float
()
key
=
key
.
float
()
value
=
value
.
float
()
key
=
key
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
value
=
value
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
/
math
.
sqrt
(
query
.
size
(
-
1
))
if
is_causal
:
s_q
=
query
.
shape
[
-
2
]
s_k
=
key
.
shape
[
-
2
]
attn_bias
=
torch
.
zeros
(
s_q
,
s_k
,
dtype
=
query
.
dtype
)
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
).
tril
(
diagonal
=
s_k
-
s_q
)
attn_bias
.
masked_fill_
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
attn_bias
.
to
(
query
.
dtype
)
attn_weight
+=
attn_bias
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
return
attn_weight
@
value
,
lse
def
ref_mla
():
out
=
torch
.
empty
(
b
,
s_q
,
h_q
,
dv
,
dtype
=
torch
.
float32
)
lse
=
torch
.
empty
(
b
,
h_q
,
s_q
,
dtype
=
torch
.
float32
)
for
i
in
range
(
b
):
begin
=
i
*
max_seqlen_pad
end
=
begin
+
cache_seqlens
[
i
]
ref_O
,
LSE
=
scaled_dot_product_attention
(
q
[
i
].
transpose
(
0
,
1
),
blocked_k
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
is_causal
=
causal
,
)
out
[
i
]
=
ref_O
.
transpose
(
0
,
1
)
lse
[
i
]
=
LSE
return
out
,
lse
out_flash
,
lse_flash
=
flash_mla
()
out_torch
,
lse_torch
=
ref_mla
()
cal_diff
(
out_flash
,
out_torch
,
"out"
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
t
=
triton
.
testing
.
do_bench
(
flash_mla
,
fast_flush
=
False
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
"
f
"TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
tests/kernels/test_fused_quant_layernorm.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
pytest
import
torch
...
...
@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def
ref_rms_norm
(
rms_norm_layer
:
RMSNorm
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
residual
is
not
None
:
residual
=
residual
.
clone
()
out
,
residual
=
rms_norm_layer
.
forward_native
(
x
,
residual
)
...
...
@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
scale_ub
is
not
None
:
assert
quant_dtype
==
torch
.
float8_e4m3fn
...
...
@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
return
ref_dynamic_per_token_quant
(
rms_norm_layer
,
x
,
quant_dtype
,
residual
,
scale_ub
)
...
...
@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
residual
is
not
None
:
residual
=
residual
.
clone
()
out
,
scales
=
ops
.
rms_norm_dynamic_per_token_quant
(
x
,
weight
,
EPS
,
...
...
@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
return
ops_dynamic_per_token_quant
(
weight
,
x
,
quant_dtype
,
residual
,
scale_ub
)
...
...
tests/kernels/test_gguf.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
from
pathlib
import
Path
from
typing
import
List
import
pytest
import
os
...
...
@@ -10,23 +9,37 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from
huggingface_hub
import
snapshot_download
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.gguf
import
_fused_moe_gguf
from
vllm.platforms
import
current_platform
from
..utils
import
models_path_prefix
# GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
# GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
GGUF_SAMPLE
=
os
.
path
.
join
(
models_path_prefix
,
"Isotr0py/test-gguf-sample"
)
GGUF_SAMPLE_MOE
=
os
.
path
.
join
(
models_path_prefix
,
"SzymonOzog/test-gguf-moe-sample"
)
def
get_gguf_sample_tensors
(
hidden_size
:
int
,
quant_type
:
GGMLQuantizationType
)
->
L
ist
[
ReaderTensor
]:
quant_type
:
GGMLQuantizationType
)
->
l
ist
[
ReaderTensor
]:
sample_dir
=
GGUF_SAMPLE
filename
=
f
"Quant_
{
quant_type
.
name
}
_
{
hidden_size
}
.gguf"
sample_file
=
Path
(
sample_dir
)
/
filename
return
GGUFReader
(
sample_file
).
tensors
DTYPES
=
[
torch
.
half
]
def
get_gguf_MoE_tensors
(
hidden_size
:
int
,
quant_type
:
GGMLQuantizationType
)
->
list
[
ReaderTensor
]:
sample_dir
=
GGUF_SAMPLE_MOE
filename
=
f
"Quant_
{
quant_type
.
name
}
_
{
hidden_size
}
.gguf"
sample_file
=
Path
(
sample_dir
)
/
filename
return
GGUFReader
(
sample_file
).
tensors
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float32
]
# Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently.
HIDDEN_SIZES
=
[
256
,
1024
]
...
...
@@ -56,7 +69,7 @@ QUANT_TYPES = [
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
]
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
QUANT_TYPES
)
@
torch
.
inference_mode
()
def
test_dequantize
(
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
...
...
@@ -126,7 +139,64 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
ref_output
=
x
@
weight
.
T
qweight
=
torch
.
tensor
(
tensor
.
data
,
device
=
"cuda"
)
output
=
ops
.
ggml_mul_mat_a8
(
qweight
,
x
,
quant_type
,
qweight
.
shape
[
0
]).
to
(
dtype
)
output
=
ops
.
ggml_mul_mat_a8
(
qweight
,
x
,
quant_type
,
qweight
.
shape
[
0
])
atols
=
{
torch
.
half
:
1
,
torch
.
bfloat16
:
1.5
,
torch
.
float
:
1.2
}
# test matrix has inputs centered around 0 and lower precision from
# bfloat16 tends to accumulate and can greatly inflate rtol
# since outputs are also very close to 0
rtols
=
{
torch
.
half
:
1e-1
,
torch
.
bfloat16
:
1e4
,
torch
.
float
:
2e1
}
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atols
[
dtype
],
rtol
=
rtols
[
dtype
])
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1
,
rtol
=
1e-1
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"top_k"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
# k-quants
GGMLQuantizationType
.
Q2_K
,
GGMLQuantizationType
.
Q3_K
,
GGMLQuantizationType
.
Q4_K
,
GGMLQuantizationType
.
Q5_K
,
GGMLQuantizationType
.
Q6_K
,
# standard quants
GGMLQuantizationType
.
Q4_0
,
GGMLQuantizationType
.
Q5_0
,
GGMLQuantizationType
.
Q8_0
,
])
@
torch
.
inference_mode
()
def
test_moe
(
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
quant_type
:
GGMLQuantizationType
,
top_k
:
int
):
current_platform
.
seed_everything
(
0
)
H
,
E
=
1024
,
256
x
=
torch
.
rand
((
num_tokens
,
H
),
dtype
=
dtype
,
device
=
"cuda"
)
topk_weights
=
torch
.
rand
(
num_tokens
,
top_k
,
device
=
"cuda"
,
dtype
=
dtype
)
topk_ids
=
torch
.
randint
(
0
,
E
,
(
num_tokens
,
top_k
),
device
=
"cuda"
)
tensors
=
get_gguf_MoE_tensors
(
hidden_size
,
quant_type
)
w13
=
tensors
[
0
]
w2
=
tensors
[
1
]
w13_dequant
=
torch
.
tensor
(
dequantize
(
w13
.
data
,
quant_type
),
device
=
"cuda"
).
to
(
dtype
)
w2_dequant
=
torch
.
tensor
(
dequantize
(
w2
.
data
,
quant_type
),
device
=
"cuda"
).
to
(
dtype
)
act
=
SiluAndMul
()
output
=
_fused_moe_gguf
(
x
,
torch
.
tensor
(
w13
.
data
,
device
=
"cuda"
),
torch
.
tensor
(
w2
.
data
,
device
=
"cuda"
),
topk_weights
,
topk_ids
,
quant_type
,
quant_type
,
act
)
ref_output
=
fused_experts
(
x
,
w13_dequant
,
w2_dequant
,
topk_weights
,
topk_ids
).
reshape
(
output
.
shape
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1
,
rtol
=
1e-1
)
tests/kernels/test_machete_mm.py
View file @
469e903b
...
...
@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`.
import
math
from
dataclasses
import
dataclass
,
fields
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -45,7 +45,7 @@ MNK_SHAPES = [
(
1024
,
8192
,
4096
),
]
GROUP_SIZES_TO_TEST
:
L
ist
[
Optional
[
int
]]
=
[
128
,
-
1
]
GROUP_SIZES_TO_TEST
:
l
ist
[
Optional
[
int
]]
=
[
128
,
-
1
]
@
dataclass
...
...
@@ -75,7 +75,7 @@ class Tensors:
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple
=
T
uple
[
L
ist
[
torch
.
dtype
],
ScalarType
,
Optional
[
torch
.
dtype
],
TestTypeTuple
=
t
uple
[
l
ist
[
torch
.
dtype
],
ScalarType
,
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
bool
]
TEST_TYPES
=
[
# GPTQ style
...
...
@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return
zps
if
zps
is
None
else
-
1
*
s
*
(
zps
.
to
(
s
.
dtype
))
def
group_size_valid
(
shape
:
T
uple
[
int
,
int
,
int
],
def
group_size_valid
(
shape
:
t
uple
[
int
,
int
,
int
],
group_size
:
Optional
[
int
])
->
bool
:
return
group_size
is
None
or
group_size
==
-
1
or
group_size
%
shape
[
2
]
==
0
...
...
@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype,
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
def
create_test_tensors
(
shape
:
T
uple
[
int
,
int
,
int
],
def
create_test_tensors
(
shape
:
t
uple
[
int
,
int
,
int
],
types
:
TypeConfig
,
group_size
:
Optional
[
int
],
subset_stride_factor
:
Optional
[
int
]
=
None
)
->
Tensors
:
...
...
@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig,
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
def
test_machete_all_schedules
(
shape
,
types
:
TypeConfig
):
group_sizes
:
L
ist
[
Optional
[
int
]]
=
[]
group_sizes
:
l
ist
[
Optional
[
int
]]
=
[]
if
types
.
group_scale_type
is
None
:
group_sizes
=
[
None
]
else
:
...
...
@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
ids
=
lambda
x
:
"x"
.
join
(
str
(
v
)
for
v
in
x
))
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
def
test_machete_heuristic
(
shape
,
types
:
TypeConfig
):
group_sizes
:
L
ist
[
Optional
[
int
]]
=
[]
group_sizes
:
l
ist
[
Optional
[
int
]]
=
[]
if
types
.
group_scale_type
is
None
:
group_sizes
=
[
None
]
else
:
...
...
tests/kernels/test_mamba_mixer2.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
import
unittest
from
typing
import
Tuple
import
pytest
import
torch
...
...
@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables
def
test_mixer2_gated_norm_multi_gpu
(
batch_size
:
int
,
seq_len
:
int
,
hidden_size_n_groups
:
T
uple
[
int
,
int
],
hidden_size_n_groups
:
t
uple
[
int
,
int
],
dtype
:
torch
.
dtype
,
device
:
str
=
'cuda'
,
):
...
...
tests/kernels/test_mamba_ssm_ssd.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
Tuple
import
pytest
import
torch
import
torch.nn.functional
as
F
...
...
@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
# given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc
def
get_continuous_batch
(
example_lens
:
T
uple
[
int
,
...]):
def
get_continuous_batch
(
example_lens
:
t
uple
[
int
,
...]):
indices
=
[]
for
i
,
x
in
enumerate
(
example_lens
):
...
...
@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken
:
D
ict
=
{}
# map: eg -> pointer to last taken sample
exhausted
:
D
ict
=
{}
# map: eg -> boolean indicating example is exhausted
last_taken
:
d
ict
=
{}
# map: eg -> pointer to last taken sample
exhausted
:
d
ict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
for
Y_min
,
cu_seqlens
,
sed_idx
,
(
A
,
dt
,
X
,
B
,
...
...
tests/kernels/test_moe.py
View file @
469e903b
...
...
@@ -3,8 +3,11 @@
Run `pytest tests/kernels/test_moe.py`.
"""
import
pytest
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
functional
as
F
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
...
...
@@ -26,6 +29,7 @@ from vllm.platforms import current_platform
from
vllm.scalar_type
import
scalar_types
NUM_EXPERTS
=
[
8
,
64
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
2
,
6
]
...
...
@@ -34,24 +38,64 @@ TOP_KS = [2, 6]
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
def
test_fused_moe
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
padding
:
bool
,
):
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
if
ep_size
>
1
:
local_e
=
e
//
ep_size
e_ids
=
torch
.
randint
(
0
,
e
,
(
local_e
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
=
torch
.
full
((
e
,
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
[
e_ids
]
=
torch
.
arange
(
local_e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
w1
=
w1
[
e_ids
]
w2
=
w2
[
e_ids
]
else
:
e_map
=
None
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
# Pad the weight if moe padding is enabled
if
padding
:
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
atol
=
2e-2
,
...
...
@@ -63,13 +107,14 @@ def test_fused_moe(
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"weight_bits"
,
[
4
,
8
])
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
has_zp
:
bool
,
weight_bits
:
int
):
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
has_zp
:
bool
,
weight_bits
:
int
):
print
(
m
,
n
,
k
,
e
,
topk
,
dtype
,
group_size
,
has_zp
,
weight_bits
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
...
@@ -130,6 +175,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
if
has_zp
:
w_qzeros
[
expert_id
]
=
qzeros
if
ep_size
>
1
:
local_e
=
e
//
ep_size
e_ids
=
torch
.
randint
(
0
,
e
,
(
local_e
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
=
torch
.
full
((
e
,
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
[
e_ids
]
=
torch
.
arange
(
local_e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
w1_ref
=
w1_ref
[
e_ids
]
w2_ref
=
w2_ref
[
e_ids
]
w1_qweight
=
w1_qweight
[
e_ids
]
w2_qweight
=
w2_qweight
[
e_ids
]
w1_scales
=
w1_scales
[
e_ids
]
w2_scales
=
w2_scales
[
e_ids
]
w1_qzeros
=
w1_qzeros
[
e_ids
]
w2_qzeros
=
w2_qzeros
[
e_ids
]
else
:
e_map
=
None
triton_output
=
fused_moe
(
a
,
w1_qweight
,
w2_qweight
,
...
...
@@ -138,19 +202,22 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize
=
False
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
weight_bits
==
8
,
global_num_experts
=
e
,
expert_map
=
e_map
,
w1_scale
=
w1_scales
,
w2_scale
=
w2_scales
,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
group_size
])
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
)
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
@
torch
.
inference_mode
()
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
):
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
,
padding
:
bool
):
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
...
...
@@ -164,6 +231,7 @@ def test_mixtral_moe(dtype: torch.dtype):
intermediate_size
=
config
.
intermediate_size
,
params_dtype
=
dtype
,
tp_size
=
1
,
dp_size
=
1
,
).
cuda
()
# Load the weights
...
...
@@ -179,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
# Pad the weight if moe padding is enabled
if
padding
:
vllm_moe
.
experts
.
w13_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
...
...
tests/kernels/test_nvfp4_scaled_mm.py
0 → 100644
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
reason
=
"Nvfp4 Requires compute capability of 10 or above."
,
allow_module_level
=
True
)
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
# m, n, k
SHAPES
=
[(
128
,
128
,
64
),
(
128
,
128
,
128
),
(
256
,
128
,
64
),
(
128
,
256
,
128
)]
PAD_SHAPES
=
[(
150
,
128
,
64
),
(
128
,
128
,
96
)]
SHAPES
.
extend
(
PAD_SHAPES
)
SEEDS
=
[
42
]
CUDA_DEVICES
=
[
'cuda:0'
]
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1fn
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
kE2M1ToFloatArray
=
[
0.
,
0.5
,
1.
,
1.5
,
2.
,
3.
,
4.
,
6.
,
]
def
e2m1_to_fp32
(
int4_value
):
signBit
=
(
int4_value
&
0x8
)
int4_absValue
=
int4_value
&
0x7
float_result
=
kE2M1ToFloatArray
[
int4_absValue
]
if
(
signBit
):
float_result
=
-
float_result
return
float_result
def
break_fp4_bytes
(
a
,
dtype
):
assert
(
a
.
dtype
==
torch
.
uint8
)
m
,
n
=
a
.
shape
a
=
a
.
flatten
()
# Get upper 4 bits
highHalfByte
=
(
a
&
0xF0
)
>>
4
# Get lower 4 bits
lowHalfByte
=
a
&
0x0F
fH
=
torch
.
tensor
([
e2m1_to_fp32
(
x
)
for
x
in
highHalfByte
]).
to
(
a
.
device
)
fL
=
torch
.
tensor
([
e2m1_to_fp32
(
x
)
for
x
in
lowHalfByte
]).
to
(
a
.
device
)
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
out
=
torch
.
stack
((
fL
,
fH
),
dim
=-
1
).
reshape
(
m
,
n
*
2
)
return
out
def
convert_swizzled_to_linear
(
a_sf_swizzled
:
torch
.
Tensor
,
m
,
k
,
block_size
):
sf_m
,
sf_k
=
a_sf_swizzled
.
shape
m_tiles
=
(
m
+
128
-
1
)
//
128
f
=
block_size
*
4
k_tiles
=
(
k
+
f
-
1
)
//
f
tmp
=
torch
.
reshape
(
a_sf_swizzled
,
(
1
,
m_tiles
,
k_tiles
,
32
,
4
,
4
))
tmp
=
torch
.
permute
(
tmp
,
(
0
,
1
,
4
,
3
,
2
,
5
))
out
=
tmp
.
reshape
(
m_tiles
*
128
,
k_tiles
*
f
//
block_size
)
return
out
[
0
:
m
,
0
:
k
]
def
dequantize_to_dtype
(
tensor_fp4
,
tensor_sf
,
global_scale
,
dtype
,
device
,
block_size
=
16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert
tensor_fp4
.
dtype
==
torch
.
uint8
m
,
packed_k
=
tensor_fp4
.
shape
k
=
packed_k
*
2
tensor_f32
=
break_fp4_bytes
(
tensor_fp4
,
dtype
)
tensor_f32
=
tensor_f32
.
reshape
(
m
,
k
//
block_size
,
block_size
)
tensor_sf
=
tensor_sf
.
view
(
torch
.
float8_e4m3fn
)
tensor_sf
=
convert_swizzled_to_linear
(
tensor_sf
,
m
,
k
,
block_size
)
tensor_sf_dtype
=
tensor_sf
.
to
(
torch
.
float32
)
/
global_scale
# scale the tensor
out
=
(
tensor_f32
*
tensor_sf_dtype
.
unsqueeze
(
-
1
)).
reshape
(
m
,
k
)
return
out
def
get_ref_results
(
a_fp4
,
b_fp4
,
a_sf
,
b_sf
,
a_global_scale
,
b_global_scale
,
m
,
n
,
dtype
,
block_size
,
device
):
_
,
m_k
=
a_fp4
.
shape
_
,
n_k
=
b_fp4
.
shape
assert
(
m_k
==
n_k
)
a_in_dtype
=
dequantize_to_dtype
(
a_fp4
,
a_sf
,
a_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
b_in_dtype
=
dequantize_to_dtype
(
b_fp4
,
b_sf
,
b_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
return
torch
.
matmul
(
a_in_dtype
,
b_in_dtype
.
t
())
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"shape"
,
SHAPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_nvfp4_gemm
(
dtype
:
torch
.
dtype
,
shape
:
tuple
[
int
,
int
,
int
],
seed
:
int
,
device
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
m
,
n
,
packed_k
=
shape
k
=
packed_k
*
2
block_size
=
16
a_dtype
=
torch
.
randn
((
m
,
k
),
dtype
=
dtype
,
device
=
device
)
b_dtype
=
torch
.
randn
((
n
,
k
),
dtype
=
dtype
,
device
=
device
)
a_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
a_dtype
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
b_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
b_dtype
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
alpha
=
1.
/
(
a_global_scale
*
b_global_scale
)
a_fp4
,
a_scale_interleaved
=
ops
.
scaled_fp4_quant
(
a_dtype
,
a_global_scale
)
b_fp4
,
b_scale_interleaved
=
ops
.
scaled_fp4_quant
(
b_dtype
,
b_global_scale
)
expected_out
=
get_ref_results
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
a_global_scale
,
b_global_scale
,
m
,
n
,
dtype
,
block_size
,
device
)
out
=
ops
.
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
dtype
)
torch
.
testing
.
assert_close
(
out
,
expected_out
.
to
(
dtype
=
dtype
),
atol
=
1e-1
,
rtol
=
1e-1
)
tests/kernels/test_pos_encoding.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
from
itertools
import
accumulate
,
product
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Optional
import
pytest
import
torch
...
...
@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
scaling_factors
:
L
ist
[
int
]
=
[
1
,
2
,
4
]
scaling_factors
:
l
ist
[
int
]
=
[
1
,
2
,
4
]
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)
...
...
@@ -234,7 +234,7 @@ def test_rope_module_cache():
})
settings
=
(
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
ROPE_SCALINGS
,
DTYPES
)
rope_setting_id_map
:
D
ict
[
str
,
int
]
=
{}
rope_setting_id_map
:
d
ict
[
str
,
int
]
=
{}
for
setting
in
product
(
*
settings
):
head_size
,
rotary_dim
,
max_position
,
base
,
\
is_neox_stype
,
rope_scaling
,
dtype
=
setting
...
...
Prev
1
…
16
17
18
19
20
21
22
23
24
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment