Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
469e903b
Commit
469e903b
authored
Mar 28, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-dev
parents
389ebcf7
25f560a6
Changes
535
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
976 additions
and
362 deletions
+976
-362
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+91
-24
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+90
-60
tests/kernels/test_awq_marlin.py
tests/kernels/test_awq_marlin.py
+2
-7
tests/kernels/test_block_fp8.py
tests/kernels/test_block_fp8.py
+2
-2
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+6
-6
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+96
-41
tests/kernels/test_cascade_flash_attn.py
tests/kernels/test_cascade_flash_attn.py
+4
-4
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+32
-34
tests/kernels/test_cutlass_2of4_sparse.py
tests/kernels/test_cutlass_2of4_sparse.py
+2
-3
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+15
-13
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+171
-131
tests/kernels/test_flashmla.py
tests/kernels/test_flashmla.py
+132
-0
tests/kernels/test_fused_quant_layernorm.py
tests/kernels/test_fused_quant_layernorm.py
+6
-6
tests/kernels/test_gguf.py
tests/kernels/test_gguf.py
+77
-7
tests/kernels/test_machete_mm.py
tests/kernels/test_machete_mm.py
+7
-7
tests/kernels/test_mamba_mixer2.py
tests/kernels/test_mamba_mixer2.py
+1
-2
tests/kernels/test_mamba_ssm_ssd.py
tests/kernels/test_mamba_ssm_ssd.py
+3
-5
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+86
-7
tests/kernels/test_nvfp4_scaled_mm.py
tests/kernels/test_nvfp4_scaled_mm.py
+150
-0
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+3
-3
No files found.
Too many changes to show.
To preserve performance only
535 of 535+
files are displayed.
Plain diff
Email patch
tests/kernels/test_attention.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
random
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
...
@@ -17,6 +17,8 @@ if not current_platform.is_rocm():
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
vllm.attention.backends.xformers
import
_make_alibi_bias
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
# This will change depending on the compute capability.
# This will change depending on the compute capability.
# - 512 as a buffer
# - 512 as a buffer
...
@@ -25,6 +27,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
...
@@ -25,6 +27,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Reduce NUM_BLOCKS when it happens.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS
=
4321
# Arbitrary values for testing
NUM_BLOCKS
=
4321
# Arbitrary values for testing
PARTITION_SIZE
=
512
PARTITION_SIZE
=
512
PARTITION_SIZE_ROCM
=
256
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES
=
[
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
...
@@ -85,8 +88,8 @@ def ref_single_query_cached_kv_attention(
...
@@ -85,8 +88,8 @@ def ref_single_query_cached_kv_attention(
block_table
=
block_tables_lst
[
i
]
block_table
=
block_tables_lst
[
i
]
seq_len
=
int
(
seq_lens_lst
[
i
])
seq_len
=
int
(
seq_lens_lst
[
i
])
keys_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
keys_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
block_offset
=
j
%
block_size
...
@@ -133,7 +136,7 @@ def test_paged_attention(
...
@@ -133,7 +136,7 @@ def test_paged_attention(
kv_cache_factory
,
kv_cache_factory
,
version
:
str
,
version
:
str
,
num_seqs
:
int
,
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
use_alibi
:
bool
,
use_alibi
:
bool
,
block_size
:
int
,
block_size
:
int
,
...
@@ -146,6 +149,8 @@ def test_paged_attention(
...
@@ -146,6 +149,8 @@ def test_paged_attention(
or
(
version
==
"rocm"
and
head_size
not
in
(
64
,
128
))):
or
(
version
==
"rocm"
and
head_size
not
in
(
64
,
128
))):
pytest
.
skip
()
pytest
.
skip
()
global
PARTITION_SIZE
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
...
@@ -166,7 +171,7 @@ def test_paged_attention(
...
@@ -166,7 +171,7 @@ def test_paged_attention(
# Create the block tables.
# Create the block tables.
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables_lst
:
L
ist
[
L
ist
[
int
]]
=
[]
block_tables_lst
:
l
ist
[
l
ist
[
int
]]
=
[]
for
_
in
range
(
num_seqs
):
for
_
in
range
(
num_seqs
):
block_table
=
[
block_table
=
[
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
...
@@ -214,6 +219,9 @@ def test_paged_attention(
...
@@ -214,6 +219,9 @@ def test_paged_attention(
and
block_size
==
BLOCK_SIZES
[
0
]))
and
block_size
==
BLOCK_SIZES
[
0
]))
elif
version
in
(
"v2"
,
"rocm"
):
elif
version
in
(
"v2"
,
"rocm"
):
if
current_platform
.
is_rocm
()
and
version
==
"rocm"
:
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
assert
PARTITION_SIZE
%
block_size
==
0
assert
PARTITION_SIZE
%
block_size
==
0
num_seqs
,
num_heads
,
head_size
=
output
.
shape
num_seqs
,
num_heads
,
head_size
=
output
.
shape
...
@@ -334,25 +342,31 @@ def test_paged_attention(
...
@@ -334,25 +342,31 @@ def test_paged_attention(
def
ref_multi_query_kv_attention
(
def
ref_multi_query_kv_attention
(
cu_seq_lens
:
L
ist
[
int
],
cu_seq_lens
:
l
ist
[
int
],
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
scale
:
float
,
scale
:
float
,
alibi_bias
:
Optional
[
list
[
torch
.
Tensor
]],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_seqs
=
len
(
cu_seq_lens
)
-
1
num_seqs
=
len
(
cu_seq_lens
)
-
1
ref_outputs
:
List
[
torch
.
Tensor
]
=
[]
ref_outputs
:
list
[
torch
.
Tensor
]
=
[]
if
alibi_bias
:
assert
len
(
alibi_bias
)
==
num_seqs
for
i
in
range
(
num_seqs
):
for
i
in
range
(
num_seqs
):
start_idx
=
cu_seq_lens
[
i
]
start_idx
=
cu_seq_lens
[
i
]
end_idx
=
cu_seq_lens
[
i
+
1
]
end_idx
=
cu_seq_lens
[
i
+
1
]
seq_len
=
end_idx
-
start_idx
seq_len
=
end_idx
-
start_idx
# Create attention mask.
# Create attention mask. ALiBi already includes a tril causal mask.
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
dtype
),
if
alibi_bias
:
diagonal
=
1
)
attn_mask
=
alibi_bias
[
i
]
attn_mask
=
attn_mask
*
torch
.
finfo
(
dtype
).
min
else
:
attn_mask
=
attn_mask
.
to
(
dtype
=
dtype
)
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
dtype
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
dtype
).
min
attn_mask
=
attn_mask
.
to
(
dtype
=
dtype
)
ref_output
=
ref_masked_attention
(
ref_output
=
ref_masked_attention
(
query
[
start_idx
:
end_idx
],
query
[
start_idx
:
end_idx
],
...
@@ -366,7 +380,6 @@ def ref_multi_query_kv_attention(
...
@@ -366,7 +380,6 @@ def ref_multi_query_kv_attention(
return
torch
.
cat
(
ref_outputs
,
dim
=
0
)
return
torch
.
cat
(
ref_outputs
,
dim
=
0
)
# TODO(woosuk): Add tests for USE_ALIBI=True.
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_PREFILL_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_PREFILL_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
...
@@ -378,11 +391,12 @@ def ref_multi_query_kv_attention(
...
@@ -378,11 +391,12 @@ def ref_multi_query_kv_attention(
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_multi_query_kv_attention
(
def
test_multi_query_kv_attention
(
num_seqs
:
int
,
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
use_alibi
:
bool
=
False
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -408,16 +422,40 @@ def test_multi_query_kv_attention(
...
@@ -408,16 +422,40 @@ def test_multi_query_kv_attention(
# Handle MQA and GQA
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_queries_per_kv
,
dim
=
1
)
key
=
torch
.
repeat_interleave
(
key
,
num_queries_per_kv
,
dim
=
1
)
value
=
torch
.
repeat_interleave
(
value
,
num_queries_per_kv
,
dim
=
1
)
value
=
torch
.
repeat_interleave
(
value
,
num_queries_per_kv
,
dim
=
1
)
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
alibi_bias
=
None
output
=
xops
.
memory_efficient_attention_forward
(
if
use_alibi
:
query
.
unsqueeze
(
0
),
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
)
key
.
unsqueeze
(
0
),
attn_bias
=
_make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
value
.
unsqueeze
(
0
),
seq_lens
)
attn_bias
=
attn_bias
,
output
=
torch
.
empty_like
(
query
)
p
=
0.0
,
start
=
0
scale
=
scale
,
# Dynamic sequence length not supported with custom attn_bias.
)
for
i
,
seq_len
in
enumerate
(
seq_lens
):
output
=
output
.
squeeze
(
0
)
end
=
start
+
seq_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
scale
)
output
[
start
:
end
].
copy_
(
out
.
view_as
(
query
[
start
:
end
]))
start
+=
seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias
=
[
b
.
materialize
(
b
.
shape
,
device
=
device
).
squeeze
()
for
b
in
attn_bias
]
else
:
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
seq_lens
)
output
=
xops
.
memory_efficient_attention_forward
(
query
.
unsqueeze
(
0
),
key
.
unsqueeze
(
0
),
value
.
unsqueeze
(
0
),
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
scale
,
)
output
=
output
.
squeeze
(
0
)
cu_seq_lens
=
[
0
]
cu_seq_lens
=
[
0
]
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
...
@@ -428,8 +466,37 @@ def test_multi_query_kv_attention(
...
@@ -428,8 +466,37 @@ def test_multi_query_kv_attention(
key
,
key
,
value
,
value
,
scale
,
scale
,
alibi_bias
,
dtype
,
dtype
,
)
)
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
@
pytest
.
mark
.
parametrize
(
"num_seqs"
,
NUM_PREFILL_SEQS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Xformers backend is not supported on ROCm."
)
@
torch
.
inference_mode
()
def
test_multi_query_kv_attention_with_alibi
(
num_seqs
:
int
,
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
)
->
None
:
return
test_multi_query_kv_attention
(
num_seqs
,
num_heads
,
head_size
,
dtype
,
seed
,
device
,
use_alibi
=
True
,
)
tests/kernels/test_attention_selector.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
unittest.mock
import
Mock
,
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.attention.selector
import
_cached_get_attn_backend
,
get_attn_backend
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cpu
import
CpuPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.cuda
import
CudaPlatform
from
vllm.platforms.openvino
import
OpenVinoPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.platforms.rocm
import
RocmPlatform
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -23,86 +22,117 @@ def clear_cache():
...
@@ -23,86 +22,117 @@ def clear_cache():
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
]
if
not
current_platform
()
else
[
"ROCM_FLASH"
])
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
]
if
not
current_platform
.
is_rocm
()
else
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"use_v1"
,
[
True
,
False
])
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
use_v1
:
bool
,
device
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
"""Test that the attention selector can be set via environment variable.
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
Note that we do not test FlashAttn because it is the default backend.
"""
"""
override_backend_env_variable
(
monkeypatch
,
name
)
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
if
use_v1
else
"0"
)
if
device
==
"cpu"
:
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
name
)
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
if
device
==
"cpu"
:
False
)
assert
backend
.
get_name
()
==
"TORCH_SDPA"
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
RocmPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
OpenVinoPlatform
()),
patch
.
dict
(
'sys.modules'
,
{
'openvino'
:
Mock
()}):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
assert
backend
.
get_name
()
==
"OPENVINO"
else
:
if
name
in
[
"XFORMERS"
,
"FLASHINFER"
]:
with
patch
(
"vllm.attention.selector.current_platform"
,
with
patch
(
"vllm.attention.selector.current_platform"
,
Cu
da
Platform
()):
C
p
uPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
16
,
False
)
assert
backend
.
get_name
()
==
name
assert
backend
.
get_name
()
==
"TORCH_SDPA"
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
def
test_flash_attn
(
monkeypatch
):
RocmPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
EXPECTED
=
"TRITON_ATTN_VLLM_V1"
if
use_v1
else
"ROCM_FLASH"
assert
backend
.
get_name
()
==
EXPECTED
else
:
if
name
in
[
"XFORMERS"
,
"FLASHINFER"
]:
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
EXPECTED
=
"FLASH_ATTN_VLLM_V1"
if
use_v1
else
name
assert
backend
.
get_name
()
==
EXPECTED
def
test_flash_attn
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""Test FlashAttn validation."""
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# get_attn_backend
# get_attn_backend
override_backend_env_variable
(
monkeypatch
,
STR_FLASH_ATTN_VAL
)
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
)
# Unsupported CUDA arch
# Unsupported CUDA arch
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
(
7
,
5
)):
monkeypatch
.
setattr
(
torch
.
cuda
,
"get_device_capability"
,
lambda
:
(
7
,
5
))
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported data type
# Reset the monkeypatch for subsequent tests
backend
=
get_attn_backend
(
16
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
monkeypatch
.
undo
()
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported
kv cache
data type
# Unsupported data type
backend
=
get_attn_backend
(
16
,
torch
.
float
16
,
"fp8"
,
16
,
False
)
backend
=
get_attn_backend
(
16
,
torch
.
float
8_e4m3fn
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported
block siz
e
# Unsupported
kv cache data typ
e
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
8
,
False
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
"fp8"
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# flash-attn is not installed
# Unsupported block size
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# flash-attn is not installed
import
sys
original_module
=
sys
.
modules
.
get
(
'vllm_flash_attn'
)
monkeypatch
.
setitem
(
sys
.
modules
,
'vllm_flash_attn'
,
None
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Unsupported head size
# Restore the original module if it existed
backend
=
get_attn_backend
(
17
,
torch
.
float16
,
None
,
16
,
False
)
if
original_module
is
not
None
:
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
monkeypatch
.
setitem
(
sys
.
modules
,
'vllm_flash_attn'
,
original_module
)
else
:
monkeypatch
.
delitem
(
sys
.
modules
,
'vllm_flash_attn'
,
raising
=
False
)
# Unsupported head size
backend
=
get_attn_backend
(
17
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
@
pytest
.
mark
.
parametrize
(
"use_v1"
,
[
True
,
False
])
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
def
test_invalid_env
(
use_v1
:
bool
,
monkeypatch
:
pytest
.
MonkeyPatch
):
assert
backend
.
get_name
()
!=
STR_FLASH_ATTN_VAL
with
monkeypatch
.
context
()
as
m
,
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
if
use_v1
else
"0"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
STR_INVALID_VAL
)
def
test_invalid_env
(
monkeypatch
):
# Test with head size 32
"""Ignore the invalid env variable if it is set."""
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
with
patch
(
"vllm.attention.selector.current_platform"
,
CudaPlatform
()):
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
get_attn_backend
(
32
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
==
"FLASH_ATTN"
EXPECTED
=
"FLASH_ATTN_VLLM_V1"
if
use_v1
else
"FLASH_ATTN"
assert
backend
.
get_name
()
==
EXPECTED
# when block size == 16, backend will fall back to XFORMERS
# when block size == 16, backend will fall back to XFORMERS
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
# this behavior is not yet supported on V1.
assert
backend
.
get_name
()
==
"XFORMERS"
if
use_v1
:
# TODO: support fallback on V1!
# https://github.com/vllm-project/vllm/issues/14524
pass
else
:
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
get_name
()
==
"XFORMERS"
tests/kernels/test_awq_marlin.py
View file @
469e903b
...
@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
...
@@ -99,13 +99,8 @@ def test_fused_marlin_moe_awq(
num_bits
=
num_bits
,
num_bits
=
num_bits
,
)
)
torch_output
=
torch_moe
(
torch_output
=
torch_moe
(
a
,
w_ref1
.
transpose
(
1
,
2
),
w_ref2
.
transpose
(
1
,
2
),
a
,
score
,
topk
,
None
)
w_ref1
.
transpose
(
1
,
2
),
w_ref2
.
transpose
(
1
,
2
),
score
,
topk
,
)
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
4e-2
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
4e-2
...
...
tests/kernels/test_block_fp8.py
View file @
469e903b
...
@@ -30,8 +30,8 @@ M_moe = [1, 7, 83, 512, 2048]
...
@@ -30,8 +30,8 @@ M_moe = [1, 7, 83, 512, 2048]
N_moe
=
[
4608
]
# [128, 4608, 13824]
N_moe
=
[
4608
]
# [128, 4608, 13824]
K_moe
=
[
7168
]
# [256, 7168, 13824]
K_moe
=
[
7168
]
# [256, 7168, 13824]
BLOCK_SIZE
=
[[
128
,
128
]]
BLOCK_SIZE
=
[[
128
,
128
]]
E
=
[
256
]
# [8, 24, 128, 256]
E
=
[
8
,
24
]
# [8, 24, 128, 256]
TOP_KS
=
[
1
]
# [1, 2, 6]
TOP_KS
=
[
2
]
# [1, 2, 6]
OUT_DTYPES
=
[
torch
.
bfloat16
]
# [torch.float32, torch.half, torch.bfloat16]
OUT_DTYPES
=
[
torch
.
bfloat16
]
# [torch.float32, torch.half, torch.bfloat16]
SEEDS
=
[
0
]
SEEDS
=
[
0
]
...
...
tests/kernels/test_blocksparse_attention.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
random
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
...
@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
block_table
=
block_tables_lst
[
i
]
block_table
=
block_tables_lst
[
i
]
seq_len
=
int
(
seq_lens_lst
[
i
])
seq_len
=
int
(
seq_lens_lst
[
i
])
keys_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
keys_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
block_offset
=
j
%
block_size
...
@@ -162,7 +162,7 @@ def test_paged_attention(
...
@@ -162,7 +162,7 @@ def test_paged_attention(
kv_cache_factory
,
kv_cache_factory
,
version
:
str
,
version
:
str
,
num_seqs
:
int
,
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
use_alibi
:
bool
,
use_alibi
:
bool
,
block_size
:
int
,
block_size
:
int
,
...
@@ -331,7 +331,7 @@ def test_paged_attention(
...
@@ -331,7 +331,7 @@ def test_paged_attention(
def
ref_multi_query_kv_attention
(
def
ref_multi_query_kv_attention
(
cu_seq_lens
:
L
ist
[
int
],
cu_seq_lens
:
l
ist
[
int
],
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
...
@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_varlen_blocksparse_attention_prefill
(
def
test_varlen_blocksparse_attention_prefill
(
num_seqs
:
int
,
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
blocksparse_local_blocks
:
int
,
blocksparse_local_blocks
:
int
,
blocksparse_vert_stride
:
int
,
blocksparse_vert_stride
:
int
,
...
...
tests/kernels/test_cache.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
random
from
typing
import
List
,
Tuple
import
pytest
import
pytest
import
torch
import
torch
...
@@ -9,7 +8,6 @@ import torch
...
@@ -9,7 +8,6 @@ import torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
align_to_256bytes
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -75,7 +73,7 @@ def test_copy_blocks(
...
@@ -75,7 +73,7 @@ def test_copy_blocks(
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
block_mapping
:
L
ist
[
T
uple
[
int
,
int
]]
=
[]
block_mapping
:
l
ist
[
t
uple
[
int
,
int
]]
=
[]
for
i
in
range
(
num_mappings
):
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
src
=
src_blocks
[
i
]
dst1
=
dst_blocks
[
2
*
i
]
dst1
=
dst_blocks
[
2
*
i
]
...
@@ -160,19 +158,20 @@ def test_reshape_and_cache(
...
@@ -160,19 +158,20 @@ def test_reshape_and_cache(
device
)
device
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
# Clone the KV caches.
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
.
item
()
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
.
item
()
)
else
:
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
@@ -183,9 +182,9 @@ def test_reshape_and_cache(
...
@@ -183,9 +182,9 @@ def test_reshape_and_cache(
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
,
k_scale
.
item
()
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
,
v_scale
.
item
()
)
# Run the reference implementation.
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
...
@@ -269,15 +268,16 @@ def test_reshape_and_cache_flash(
...
@@ -269,15 +268,16 @@ def test_reshape_and_cache_flash(
del
key_caches
del
key_caches
del
value_caches
del
value_caches
k_scale
=
(
key
.
amax
()
/
25
6.0
).
to
(
torch
.
float32
)
k_scale
=
(
key
.
amax
()
/
6
4
.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
25
6.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
6
4
.0
).
to
(
torch
.
float32
)
# Clone the KV caches.
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
,
kv_cache_dtype
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
.
item
(),
kv_cache_dtype
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
,
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
.
item
()
,
kv_cache_dtype
)
kv_cache_dtype
)
else
:
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_key_cache
=
key_cache
.
clone
()
...
@@ -341,7 +341,7 @@ def test_reshape_and_cache_flash(
...
@@ -341,7 +341,7 @@ def test_reshape_and_cache_flash(
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_swap_blocks
(
def
test_swap_blocks
(
kv_cache_factory
,
kv_cache_factory
,
direction
:
T
uple
[
str
,
str
],
direction
:
t
uple
[
str
,
str
],
num_mappings
:
int
,
num_mappings
:
int
,
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
...
@@ -452,22 +452,13 @@ def _create_mla_cache(
...
@@ -452,22 +452,13 @@ def _create_mla_cache(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
device
:
str
,
device
:
str
,
align_cache
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
cache_dtype
=
torch
.
uint8
if
kv_cache_dtype
==
"fp8"
else
dtype
cache_dtype
=
torch
.
uint8
if
kv_cache_dtype
==
"fp8"
else
dtype
return
torch
.
zeros
(
num_blocks
,
if
align_cache
:
block_size
,
alloc_entry_size
=
align_to_256bytes
(
entry_size
,
cache_dtype
)
entry_size
,
alloc_shape
=
(
num_blocks
,
block_size
,
alloc_entry_size
)
dtype
=
cache_dtype
,
cache_full
=
torch
.
zeros
(
alloc_shape
,
dtype
=
cache_dtype
,
device
=
device
)
device
=
device
)
cache
=
cache_full
[...,
:
entry_size
]
else
:
cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
entry_size
,
dtype
=
cache_dtype
,
device
=
device
)
return
cache
def
_fill_mla_cache
(
cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
):
def
_fill_mla_cache
(
cache
:
torch
.
Tensor
,
kv_cache_dtype
:
str
):
...
@@ -490,7 +481,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
...
@@ -490,7 +481,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_concat_and_cache_mla
(
def
test_concat_and_cache_mla
(
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
...
@@ -502,7 +492,6 @@ def test_concat_and_cache_mla(
...
@@ -502,7 +492,6 @@ def test_concat_and_cache_mla(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -522,7 +511,7 @@ def test_concat_and_cache_mla(
...
@@ -522,7 +511,7 @@ def test_concat_and_cache_mla(
scale
=
torch
.
tensor
(
0.1
,
dtype
=
torch
.
float32
,
device
=
device
)
scale
=
torch
.
tensor
(
0.1
,
dtype
=
torch
.
float32
,
device
=
device
)
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
ref_temp
=
torch
.
zeros
(
*
kv_cache
.
shape
,
dtype
=
dtype
,
device
=
device
)
ref_temp
=
torch
.
zeros
(
*
kv_cache
.
shape
,
dtype
=
dtype
,
device
=
device
)
for
i
in
range
(
num_tokens
):
for
i
in
range
(
num_tokens
):
...
@@ -578,7 +567,6 @@ def test_concat_and_cache_mla(
...
@@ -578,7 +567,6 @@ def test_concat_and_cache_mla(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
,
True
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_copy_blocks_mla
(
def
test_copy_blocks_mla
(
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
...
@@ -590,7 +578,6 @@ def test_copy_blocks_mla(
...
@@ -590,7 +578,6 @@ def test_copy_blocks_mla(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -600,7 +587,7 @@ def test_copy_blocks_mla(
...
@@ -600,7 +587,7 @@ def test_copy_blocks_mla(
kv_caches
=
[]
kv_caches
=
[]
for
_
in
range
(
num_layers
):
for
_
in
range
(
num_layers
):
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
kv_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
_fill_mla_cache
(
kv_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
kv_caches
.
append
(
kv_cache
)
kv_caches
.
append
(
kv_cache
)
...
@@ -644,7 +631,6 @@ def test_copy_blocks_mla(
...
@@ -644,7 +631,6 @@ def test_copy_blocks_mla(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPE
)
@
pytest
.
mark
.
parametrize
(
"align_cache"
,
[
False
,
True
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_swap_blocks_mla
(
def
test_swap_blocks_mla
(
kv_lora_rank
:
int
,
kv_lora_rank
:
int
,
...
@@ -655,7 +641,6 @@ def test_swap_blocks_mla(
...
@@ -655,7 +641,6 @@ def test_swap_blocks_mla(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
align_cache
:
bool
,
)
->
None
:
)
->
None
:
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
...
@@ -663,9 +648,9 @@ def test_swap_blocks_mla(
...
@@ -663,9 +648,9 @@ def test_swap_blocks_mla(
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
dst_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
dst_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
,
align_cache
)
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
)
_fill_mla_cache
(
dst_cache
,
kv_cache_dtype
)
_fill_mla_cache
(
dst_cache
,
kv_cache_dtype
)
...
@@ -685,8 +670,6 @@ def test_swap_blocks_mla(
...
@@ -685,8 +670,6 @@ def test_swap_blocks_mla(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_cache
,
dst_cache
,
block_mapping_tensor
),
(
src_cache
,
dst_cache
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
cond
=
(
kv_lora_rank
==
KV_LORA_RANKS
[
0
]
and
qk_rope_head_dim
==
QK_ROPE_HEAD_DIMS
[
0
]),
)
)
ops
.
swap_blocks
(
src_cache
,
dst_cache
,
block_mapping_tensor
)
ops
.
swap_blocks
(
src_cache
,
dst_cache
,
block_mapping_tensor
)
...
@@ -697,3 +680,75 @@ def test_swap_blocks_mla(
...
@@ -697,3 +680,75 @@ def test_swap_blocks_mla(
dst_cache
[
dst
].
cpu
(),
dst_cache
[
dst
].
cpu
(),
msg
=
f
"Block
{
src
}
from src should have been swapped to block "
msg
=
f
"Block
{
src
}
from src should have been swapped to block "
f
"
{
dst
}
in dst_cache."
)
f
"
{
dst
}
in dst_cache."
)
@
pytest
.
mark
.
parametrize
(
"kv_lora_rank"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"qk_rope_head_dim"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"max_seq_len"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
])
# You can also test "fp8" if needed.
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
device
):
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
seq_len_tensor
=
torch
.
randint
(
0
,
max_seq_len
+
1
,
(
batch_size
,
),
device
=
device
)
total_tokens
=
seq_len_tensor
.
sum
()
cu_seq_lens
=
torch
.
empty
((
batch_size
+
1
),
dtype
=
torch
.
int32
,
device
=
device
)
cu_seq_lens
[
0
]
=
0
cu_seq_lens
[
1
:]
=
seq_len_tensor
.
cumsum
(
dim
=
0
).
to
(
dtype
=
torch
.
int32
)
print
(
"seq_len_tensor"
,
seq_len_tensor
)
tot_blocks_tensor
=
(
seq_len_tensor
+
block_size
-
1
)
//
block_size
block_table
=
torch
.
empty
((
batch_size
,
num_blocks
),
dtype
=
torch
.
int32
,
device
=
device
)
for
b
in
range
(
batch_size
):
perm
=
torch
.
randperm
(
num_blocks
,
device
=
device
)
block_table
[
b
,
:]
=
perm
dst
=
torch
.
zeros
((
total_tokens
,
entry_size
),
dtype
=
src_cache
.
dtype
,
device
=
device
)
expected_batches
=
[]
for
b
in
range
(
batch_size
):
s
=
seq_len_tensor
[
b
]
if
s
==
0
:
continue
tot
=
tot_blocks_tensor
[
b
]
blocks
=
block_table
[
b
,
:
tot
].
tolist
()
gathered_rows
=
[]
for
i
in
range
(
tot
-
1
):
gathered_rows
.
append
(
src_cache
[
blocks
[
i
]])
remaining
=
s
-
(
tot
-
1
)
*
block_size
gathered_rows
.
append
(
src_cache
[
blocks
[
-
1
],
:
remaining
,
:])
batch_expected
=
torch
.
cat
(
gathered_rows
,
dim
=
0
)
expected_batches
.
append
(
batch_expected
)
expected
=
torch
.
cat
(
expected_batches
,
dim
=
0
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
gather_cache
,
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
None
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
)
ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
)
torch
.
testing
.
assert_close
(
dst
,
expected
)
tests/kernels/test_cascade_flash_attn.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
...
@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_merge_kernel
(
def
test_merge_kernel
(
num_tokens
:
int
,
num_tokens
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
):
):
...
@@ -85,8 +85,8 @@ CASES = [
...
@@ -85,8 +85,8 @@ CASES = [
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_cascade
(
def
test_cascade
(
seq_lens_and_common_prefix
:
T
uple
[
L
ist
[
T
uple
[
int
,
int
]],
int
],
seq_lens_and_common_prefix
:
t
uple
[
l
ist
[
t
uple
[
int
,
int
]],
int
],
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
block_size
:
int
,
...
...
tests/kernels/test_cutlass.py
View file @
469e903b
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_cutlass.py`.
Run `pytest tests/kernels/test_cutlass.py`.
"""
"""
from
typing
import
Type
,
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -82,7 +81,7 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -82,7 +81,7 @@ def cutlass_fp8_gemm_helper(m: int,
a_scale_group_shape
:
tuple
,
a_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
use_bias
:
bool
,
out_dtype
:
T
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
t
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
# and per-output channel weight quantization.
...
@@ -120,7 +119,7 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -120,7 +119,7 @@ def cutlass_int8_gemm_helper(m: int,
a_scale_group_shape
:
tuple
,
a_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
use_bias
:
bool
,
out_dtype
:
T
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
t
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
# and per-output channel weight quantization.
...
@@ -198,7 +197,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
...
@@ -198,7 +197,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
a_scale_group_shape
,
def
test_cutlass_int8_gemm_output_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
T
ype
[
torch
.
dtype
],
out_dtype
:
t
ype
[
torch
.
dtype
],
use_bias
:
bool
):
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
512
,
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
...
@@ -208,26 +207,25 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
...
@@ -208,26 +207,25 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
use_bias
,
use_bias
,
out_dtype
=
out_dtype
)
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"a_scale_group_shape"
,
# @pytest.mark.parametrize("a_scale_group_shape",
[
PER_TOKEN_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
# [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@
pytest
.
mark
.
parametrize
(
"b_scale_group_shape"
,
# @pytest.mark.parametrize("b_scale_group_shape",
[
PER_OUT_CH_GROUP_SHAPE
,
TENSORWISE_GROUP_SHAPE
])
# [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
# @pytest.mark.parametrize("use_bias", [True, False])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
reason
=
"FP8 is not supported on this GPU type."
)
# reason="FP8 is not supported on this GPU type.")
def
test_cutlass_fp8_gemm_output_dtype
(
a_scale_group_shape
,
# def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape
,
# b_scale_group_shape,
out_dtype
:
type
[
torch
.
dtype
],
# out_dtype: Type[torch.dtype],
use_bias
:
bool
):
# use_bias: bool):
cutlass_fp8_gemm_helper
(
512
,
# cutlass_fp8_gemm_helper(512,
512
,
# 512,
512
,
# 512,
a_scale_group_shape
,
# a_scale_group_shape,
b_scale_group_shape
,
# b_scale_group_shape,
use_bias
,
# use_bias,
out_dtype
=
out_dtype
)
# out_dtype=out_dtype)
# @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
# @pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
...
@@ -238,7 +236,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
...
@@ -238,7 +236,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
# reason="FP8 blockwise is not supported on this GPU type.")
# reason="FP8 blockwise is not supported on this GPU type.")
# def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
# def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
# b_scale_group_shape,
# b_scale_group_shape,
# out_dtype:
T
ype[torch.dtype],
# out_dtype:
t
ype[torch.dtype],
# use_bias: bool):
# use_bias: bool):
# cutlass_fp8_gemm_helper(512,
# cutlass_fp8_gemm_helper(512,
# 512,
# 512,
...
@@ -271,15 +269,15 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
...
@@ -271,15 +269,15 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("use_bias", [True, False])
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
# def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias
:
bool
,
device
:
str
):
#
use_bias: bool, device: str):
cutlass_int8_gemm_helper
(
512
,
#
cutlass_int8_gemm_helper(512,
512
,
#
512,
512
,
#
512,
a_scale_group_shape
,
#
a_scale_group_shape,
b_scale_group_shape
,
#
b_scale_group_shape,
use_bias
,
#
use_bias,
out_dtype
=
torch
.
bfloat16
,
#
out_dtype=torch.bfloat16,
device
=
device
)
#
device=device)
...
...
tests/kernels/test_cutlass_2of4_sparse.py
View file @
469e903b
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_semi_structured.py`.
Run `pytest tests/kernels/test_semi_structured.py`.
"""
"""
from
typing
import
Tuple
,
Type
import
pytest
import
pytest
import
torch
import
torch
...
@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
...
@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
def
make_rand_sparse_tensors
(
def
make_rand_sparse_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
a
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
)
a
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
)
b
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
).
t
()
b
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
).
t
()
...
@@ -167,7 +166,7 @@ MNK_FACTORS = [
...
@@ -167,7 +166,7 @@ MNK_FACTORS = [
@
pytest
.
mark
.
parametrize
(
"m, n, k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"m, n, k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_sparse_gemm
(
m
:
int
,
k
:
int
,
n
:
int
,
dtype
:
T
ype
[
torch
.
dtype
],
def
test_cutlass_sparse_gemm
(
m
:
int
,
k
:
int
,
n
:
int
,
dtype
:
t
ype
[
torch
.
dtype
],
use_bias
:
bool
):
use_bias
:
bool
):
# Create tensors
# Create tensors
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
469e903b
...
@@ -22,6 +22,16 @@ from vllm.config import VllmConfig, set_current_vllm_config
...
@@ -22,6 +22,16 @@ from vllm.config import VllmConfig, set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
Encoder-decoder is only supported on V0, so set
VLLM_USE_V1=0 for all tests in the module.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
# List of support backends for encoder/decoder models
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
HEAD_SIZES
=
[
64
,
256
]
HEAD_SIZES
=
[
64
,
256
]
...
@@ -243,7 +253,7 @@ def _decoder_attn_setup(
...
@@ -243,7 +253,7 @@ def _decoder_attn_setup(
test_pt
:
TestPoint
,
test_pt
:
TestPoint
,
test_rsrcs
:
TestResources
,
test_rsrcs
:
TestResources
,
block_base_addr
:
int
=
0
,
block_base_addr
:
int
=
0
,
)
->
T
uple
[
QKVInputs
,
PhaseTestParameters
,
PhaseTestParameters
,
int
]:
)
->
t
uple
[
QKVInputs
,
PhaseTestParameters
,
PhaseTestParameters
,
int
]:
'''
'''
Set up test vectors & data structures for self-attention test.
Set up test vectors & data structures for self-attention test.
...
@@ -421,7 +431,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
...
@@ -421,7 +431,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
test_pt
:
TestPoint
,
test_pt
:
TestPoint
,
test_rsrcs
:
TestResources
,
test_rsrcs
:
TestResources
,
block_base_addr
:
int
=
0
,
block_base_addr
:
int
=
0
,
)
->
T
uple
[
PhaseTestParameters
,
PhaseTestParameters
]:
)
->
t
uple
[
PhaseTestParameters
,
PhaseTestParameters
]:
'''
'''
Set up test vectors & data structures for cross-attention test.
Set up test vectors & data structures for cross-attention test.
...
@@ -644,11 +654,7 @@ def _run_encoder_attention_test(
...
@@ -644,11 +654,7 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
)
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
)
def
_run_decoder_self_attention_test
(
def
_run_decoder_self_attention_test
(
...
@@ -682,7 +688,6 @@ def _run_decoder_self_attention_test(
...
@@ -682,7 +688,6 @@ def _run_decoder_self_attention_test(
& attn_metadata
& attn_metadata
'''
'''
attn
=
test_rsrcs
.
attn
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
assert
packed_qkv
is
not
None
with
set_forward_context
(
attn_metadata
,
vllm_config
):
with
set_forward_context
(
attn_metadata
,
vllm_config
):
...
@@ -695,8 +700,7 @@ def _run_decoder_self_attention_test(
...
@@ -695,8 +700,7 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
)
kv_cache
,
attn_metadata
)
def
_run_encoder_decoder_cross_attention_test
(
def
_run_encoder_decoder_cross_attention_test
(
...
@@ -744,7 +748,6 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -744,7 +748,6 @@ def _run_encoder_decoder_cross_attention_test(
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
attn
=
test_rsrcs
.
attn
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
if
cross_test_params
is
None
:
if
cross_test_params
is
None
:
key
=
None
key
=
None
value
=
None
value
=
None
...
@@ -762,8 +765,7 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -762,8 +765,7 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
return
attn
.
forward
(
reshaped_query
,
key
,
value
)
attn_metadata
)
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
...
...
tests/kernels/test_flash_attn.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -8,8 +8,8 @@ import torch
...
@@ -8,8 +8,8 @@ import torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
current_platform
():
if
current_platform
.
is_rocm
():
import
flash_attn
from
flash_attn
import
flash_attn
_varlen_func
else
:
else
:
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
flash_attn_varlen_func
,
...
@@ -20,6 +20,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
...
@@ -20,6 +20,7 @@ NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES
=
[
128
,
256
]
HEAD_SIZES
=
[
128
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
BLOCK_SIZES
=
[
16
,
32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
QDTYPES
=
[
None
,
torch
.
float8_e4m3fn
]
# one value large enough to test overflow in index calculation.
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
# one value small enough to test the schema op check
NUM_BLOCKS
=
[
32768
,
2048
]
NUM_BLOCKS
=
[
32768
,
2048
]
...
@@ -29,8 +30,8 @@ def ref_paged_attn(
...
@@ -29,8 +30,8 @@ def ref_paged_attn(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
query_lens
:
L
ist
[
int
],
query_lens
:
l
ist
[
int
],
kv_lens
:
L
ist
[
int
],
kv_lens
:
l
ist
[
int
],
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
scale
:
float
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
...
@@ -40,7 +41,7 @@ def ref_paged_attn(
...
@@ -40,7 +41,7 @@ def ref_paged_attn(
block_tables
=
block_tables
.
cpu
().
numpy
()
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
outputs
:
L
ist
[
torch
.
Tensor
]
=
[]
outputs
:
l
ist
[
torch
.
Tensor
]
=
[]
start_idx
=
0
start_idx
=
0
for
i
in
range
(
num_seqs
):
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
query_len
=
query_lens
[
i
]
...
@@ -79,91 +80,124 @@ def ref_paged_attn(
...
@@ -79,91 +80,124 @@ def ref_paged_attn(
return
torch
.
cat
(
outputs
,
dim
=
0
)
return
torch
.
cat
(
outputs
,
dim
=
0
)
if
not
current_platform
():
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
reason
=
"flash_attn_with_paged_kv is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
def
test_flash_attn_with_paged_kv
(
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
QDTYPES
)
use_out
:
bool
,
@
torch
.
inference_mode
()
kv_lens
:
List
[
int
],
def
test_flash_attn_with_paged_kv
(
num_heads
:
Tuple
[
int
,
int
],
use_out
:
bool
,
head_size
:
int
,
kv_lens
:
list
[
int
],
dtype
:
torch
.
dtype
,
num_heads
:
tuple
[
int
,
int
],
block_size
:
int
,
head_size
:
int
,
soft_cap
:
Optional
[
float
],
dtype
:
torch
.
dtype
,
num_blocks
:
int
,
block_size
:
int
,
sliding_window
:
Optional
[
int
],
soft_cap
:
Optional
[
float
],
fa_version
:
int
,
num_blocks
:
int
,
)
->
None
:
sliding_window
:
Optional
[
int
],
torch
.
set_default_device
(
"cuda"
)
fa_version
:
int
,
if
not
is_fa_version_supported
(
fa_version
):
q_dtype
:
Optional
[
torch
.
dtype
],
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
)
->
None
:
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
current_platform
.
seed_everything
(
0
)
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
num_seqs
=
len
(
kv_lens
)
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
num_query_heads
=
num_heads
[
0
]
if
q_dtype
is
not
None
and
(
dtype
!=
torch
.
bfloat16
or
fa_version
==
2
):
num_kv_heads
=
num_heads
[
1
]
pytest
.
skip
(
"Flash attention with quantized inputs is only "
assert
num_query_heads
%
num_kv_heads
==
0
"supported on version 3 with bfloat16 base type"
)
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
current_platform
.
seed_everything
(
0
)
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
num_seqs
=
len
(
kv_lens
)
(
-
1
,
-
1
))
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
assert
num_query_heads
%
num_kv_heads
==
0
key_cache
=
torch
.
randn
(
num_blocks
,
max_kv_len
=
max
(
kv_lens
)
block_size
,
scale
=
head_size
**-
0.5
num_kv_heads
,
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
head_size
,
(
-
1
,
-
1
))
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
key_cache
=
torch
.
randn
(
num_blocks
,
block_size
,
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
num_kv_heads
,
block_tables
=
torch
.
randint
(
0
,
head_size
,
num_blocks
,
dtype
=
dtype
)
(
num_seqs
,
max_num_blocks_per_seq
),
value_cache
=
torch
.
randn_like
(
key_cache
)
dtype
=
torch
.
int32
)
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
q
=
query
.
unsqueeze
(
1
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
out
=
torch
.
empty_like
(
q
)
if
use_out
else
None
block_tables
=
torch
.
randint
(
0
,
output
=
flash_attn_with_kvcache
(
num_blocks
,
q
=
q
,
(
num_seqs
,
max_num_blocks_per_seq
),
k_cache
=
key_cache
,
dtype
=
torch
.
int32
)
v_cache
=
value_cache
,
out
=
out
,
q
=
query
.
unsqueeze
(
1
)
softmax_scale
=
scale
,
out
=
torch
.
empty_like
(
q
)
if
use_out
else
None
causal
=
True
,
block_table
=
block_tables
,
maybe_quantized_query
=
q
cache_seqlens
=
kv_lens_tensor
,
maybe_quantized_key_cache
=
key_cache
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
maybe_quantized_value_cache
=
value_cache
window_size
=
window_size
,
q_descale
=
None
fa_version
=
fa_version
,
k_descale
=
None
)
v_descale
=
None
output
=
output
if
not
use_out
else
out
if
q_dtype
is
not
None
:
output
=
output
.
squeeze
(
1
)
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query
=
query
.
to
(
q_dtype
)
ref_output
=
ref_paged_attn
(
query
=
query
,
maybe_quantized_key_cache
=
key_cache
.
to
(
q_dtype
)
key_cache
=
key_cache
,
maybe_quantized_value_cache
=
value_cache
.
to
(
q_dtype
)
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
scale_shape
=
(
num_seqs
,
num_kv_heads
)
kv_lens
=
kv_lens
,
q_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
block_tables
=
block_tables
,
k_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
scale
=
scale
,
v_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
output
=
flash_attn_with_kvcache
(
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
q
=
maybe_quantized_query
,
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
k_cache
=
maybe_quantized_key_cache
,
v_cache
=
maybe_quantized_value_cache
,
out
=
out
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
fa_version
=
fa_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
output
=
output
if
not
use_out
else
out
output
=
output
.
squeeze
(
1
)
atol
,
rtol
=
1.5e-2
,
1e-2
if
q_dtype
is
not
None
:
atol
,
rtol
=
1.5e-1
,
1.5e-1
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
,
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"varlen_with_paged_kv is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_out"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
[[(
1
,
1328
),
(
5
,
18
),
...
@@ -176,11 +210,12 @@ if not current_platform():
...
@@ -176,11 +210,12 @@ if not current_platform():
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
QDTYPES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
def
test_varlen_with_paged_kv
(
use_out
:
bool
,
use_out
:
bool
,
seq_lens
:
L
ist
[
T
uple
[
int
,
int
]],
seq_lens
:
l
ist
[
t
uple
[
int
,
int
]],
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
...
@@ -188,11 +223,15 @@ def test_varlen_with_paged_kv(
...
@@ -188,11 +223,15 @@ def test_varlen_with_paged_kv(
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
num_blocks
:
int
,
fa_version
:
int
,
fa_version
:
int
,
q_dtype
:
Optional
[
torch
.
dtype
],
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
if
not
is_fa_version_supported
(
fa_version
):
if
not
is_fa_version_supported
(
fa_version
):
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
pytest
.
skip
(
f
"Flash attention version
{
fa_version
}
not supported due "
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
f
"to:
\"
{
fa_version_unsupported_reason
(
fa_version
)
}
\"
"
)
if
q_dtype
is
not
None
and
(
dtype
!=
torch
.
bfloat16
or
fa_version
==
2
):
pytest
.
skip
(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
...
@@ -219,9 +258,6 @@ def test_varlen_with_paged_kv(
...
@@ -219,9 +258,6 @@ def test_varlen_with_paged_kv(
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
cu_kv_lens
=
torch
.
tensor
([
0
]
+
kv_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
kv_lens
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
kv_lens
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
...
@@ -231,42 +267,43 @@ def test_varlen_with_paged_kv(
...
@@ -231,42 +267,43 @@ def test_varlen_with_paged_kv(
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
out
=
torch
.
empty_like
(
query
)
if
use_out
else
None
out
=
torch
.
empty_like
(
query
)
if
use_out
else
None
if
current_platform
():
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
# fa_version=fa_version,
)
else
:
output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
fa_version
=
fa_version
,
)
maybe_quantized_query
=
query
maybe_quantized_key_cache
=
key_cache
maybe_quantized_value_cache
=
value_cache
q_descale
=
None
k_descale
=
None
v_descale
=
None
if
q_dtype
is
not
None
:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query
=
query
.
to
(
q_dtype
)
maybe_quantized_key_cache
=
key_cache
.
to
(
q_dtype
)
maybe_quantized_value_cache
=
value_cache
.
to
(
q_dtype
)
scale_shape
=
(
num_seqs
,
num_kv_heads
)
q_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
k_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
v_descale
=
torch
.
ones
(
scale_shape
,
dtype
=
torch
.
float32
)
output
=
flash_attn_varlen_func
(
q
=
maybe_quantized_query
,
k
=
maybe_quantized_key_cache
,
v
=
maybe_quantized_value_cache
,
out
=
out
,
cu_seqlens_q
=
cu_query_lens
,
seqused_k
=
kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
fa_version
=
fa_version
,
q_descale
=
q_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
)
output
=
output
if
not
use_out
else
out
output
=
output
if
not
use_out
else
out
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
...
@@ -280,5 +317,8 @@ def test_varlen_with_paged_kv(
...
@@ -280,5 +317,8 @@ def test_varlen_with_paged_kv(
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
soft_cap
=
soft_cap
,
soft_cap
=
soft_cap
,
)
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
atol
,
rtol
=
1.5e-2
,
1e-2
if
q_dtype
is
not
None
:
atol
,
rtol
=
1.5e-1
,
1.5e-1
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
\ No newline at end of file
tests/kernels/test_flashmla.py
0 → 100644
View file @
469e903b
# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py
# SPDX-License-Identifier: Apache-2.0
import
math
import
random
import
pytest
import
torch
import
triton
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
is_flashmla_supported
)
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
x
,
y
=
x
.
double
(),
y
.
double
()
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
(
(
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
assert
cos_diff
<
1e-5
FLASH_MLA_UNSUPPORTED_REASON
=
is_flashmla_supported
()[
1
]
\
if
not
is_flashmla_supported
()[
0
]
else
"FlashMLA is supported"
@
pytest
.
mark
.
skipif
(
not
is_flashmla_supported
()[
0
],
reason
=
FLASH_MLA_UNSUPPORTED_REASON
)
@
pytest
.
mark
.
parametrize
(
"b"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"s_q"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"h_q"
,
[
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"h_kv"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
576
])
@
pytest
.
mark
.
parametrize
(
"dv"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
torch
.
inference_mode
()
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
block_size
,
causal
,
varlen
):
# TODO: parametrize using pytest
dtype
=
torch
.
bfloat16
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
, "
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
"
)
cache_seqlens
=
torch
.
full
((
b
,
),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
for
i
in
range
(
b
):
cache_seqlens
[
i
]
=
max
(
random
.
normalvariate
(
mean_sk
,
mean_sk
/
2
),
s_q
)
total_seqlens
=
cache_seqlens
.
sum
().
item
()
max_seqlen
=
cache_seqlens
.
max
().
item
()
max_seqlen_pad
=
triton
.
cdiv
(
max_seqlen
,
256
)
*
256
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
for
i
in
range
(
b
):
blocked_k
.
view
(
b
,
max_seqlen_pad
,
h_kv
,
d
)[
i
,
cache_seqlens
[
i
].
item
():]
=
float
(
"nan"
)
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
def
flash_mla
():
return
flash_mla_with_kvcache
(
q
,
blocked_k
,
block_table
,
cache_seqlens
,
dv
,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
)
def
scaled_dot_product_attention
(
query
,
key
,
value
,
is_causal
=
False
):
query
=
query
.
float
()
key
=
key
.
float
()
value
=
value
.
float
()
key
=
key
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
value
=
value
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
/
math
.
sqrt
(
query
.
size
(
-
1
))
if
is_causal
:
s_q
=
query
.
shape
[
-
2
]
s_k
=
key
.
shape
[
-
2
]
attn_bias
=
torch
.
zeros
(
s_q
,
s_k
,
dtype
=
query
.
dtype
)
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
).
tril
(
diagonal
=
s_k
-
s_q
)
attn_bias
.
masked_fill_
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
attn_bias
.
to
(
query
.
dtype
)
attn_weight
+=
attn_bias
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
return
attn_weight
@
value
,
lse
def
ref_mla
():
out
=
torch
.
empty
(
b
,
s_q
,
h_q
,
dv
,
dtype
=
torch
.
float32
)
lse
=
torch
.
empty
(
b
,
h_q
,
s_q
,
dtype
=
torch
.
float32
)
for
i
in
range
(
b
):
begin
=
i
*
max_seqlen_pad
end
=
begin
+
cache_seqlens
[
i
]
ref_O
,
LSE
=
scaled_dot_product_attention
(
q
[
i
].
transpose
(
0
,
1
),
blocked_k
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
is_causal
=
causal
,
)
out
[
i
]
=
ref_O
.
transpose
(
0
,
1
)
lse
[
i
]
=
LSE
return
out
,
lse
out_flash
,
lse_flash
=
flash_mla
()
out_torch
,
lse_torch
=
ref_mla
()
cal_diff
(
out_flash
,
out_torch
,
"out"
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
t
=
triton
.
testing
.
do_bench
(
flash_mla
,
fast_flush
=
False
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
"
f
"TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
tests/kernels/test_fused_quant_layernorm.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
pytest
import
pytest
import
torch
import
torch
...
@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
...
@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def
ref_rms_norm
(
rms_norm_layer
:
RMSNorm
,
def
ref_rms_norm
(
rms_norm_layer
:
RMSNorm
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
])
\
residual
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
residual
is
not
None
:
if
residual
is
not
None
:
residual
=
residual
.
clone
()
residual
=
residual
.
clone
()
out
,
residual
=
rms_norm_layer
.
forward_native
(
x
,
residual
)
out
,
residual
=
rms_norm_layer
.
forward_native
(
x
,
residual
)
...
@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
...
@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
scale_ub
is
not
None
:
if
scale_ub
is
not
None
:
assert
quant_dtype
==
torch
.
float8_e4m3fn
assert
quant_dtype
==
torch
.
float8_e4m3fn
...
@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
...
@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
return
ref_dynamic_per_token_quant
(
rms_norm_layer
,
x
,
quant_dtype
,
return
ref_dynamic_per_token_quant
(
rms_norm_layer
,
x
,
quant_dtype
,
residual
,
scale_ub
)
residual
,
scale_ub
)
...
@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
...
@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
residual
is
not
None
:
if
residual
is
not
None
:
residual
=
residual
.
clone
()
residual
=
residual
.
clone
()
out
,
scales
=
ops
.
rms_norm_dynamic_per_token_quant
(
x
,
weight
,
EPS
,
out
,
scales
=
ops
.
rms_norm_dynamic_per_token_quant
(
x
,
weight
,
EPS
,
...
@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
...
@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
return
ops_dynamic_per_token_quant
(
weight
,
x
,
quant_dtype
,
residual
,
return
ops_dynamic_per_token_quant
(
weight
,
x
,
quant_dtype
,
residual
,
scale_ub
)
scale_ub
)
...
...
tests/kernels/test_gguf.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
import
pytest
import
pytest
import
os
import
os
...
@@ -10,23 +9,37 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
...
@@ -10,23 +9,37 @@ from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
import
vllm._custom_ops
as
ops
import
vllm._custom_ops
as
ops
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.gguf
import
_fused_moe_gguf
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
..utils
import
models_path_prefix
from
..utils
import
models_path_prefix
# GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
# GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
# GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
GGUF_SAMPLE
=
os
.
path
.
join
(
models_path_prefix
,
"Isotr0py/test-gguf-sample"
)
GGUF_SAMPLE
=
os
.
path
.
join
(
models_path_prefix
,
"Isotr0py/test-gguf-sample"
)
GGUF_SAMPLE_MOE
=
os
.
path
.
join
(
models_path_prefix
,
"SzymonOzog/test-gguf-moe-sample"
)
def
get_gguf_sample_tensors
(
def
get_gguf_sample_tensors
(
hidden_size
:
int
,
hidden_size
:
int
,
quant_type
:
GGMLQuantizationType
)
->
L
ist
[
ReaderTensor
]:
quant_type
:
GGMLQuantizationType
)
->
l
ist
[
ReaderTensor
]:
sample_dir
=
GGUF_SAMPLE
sample_dir
=
GGUF_SAMPLE
filename
=
f
"Quant_
{
quant_type
.
name
}
_
{
hidden_size
}
.gguf"
filename
=
f
"Quant_
{
quant_type
.
name
}
_
{
hidden_size
}
.gguf"
sample_file
=
Path
(
sample_dir
)
/
filename
sample_file
=
Path
(
sample_dir
)
/
filename
return
GGUFReader
(
sample_file
).
tensors
return
GGUFReader
(
sample_file
).
tensors
DTYPES
=
[
torch
.
half
]
def
get_gguf_MoE_tensors
(
hidden_size
:
int
,
quant_type
:
GGMLQuantizationType
)
->
list
[
ReaderTensor
]:
sample_dir
=
GGUF_SAMPLE_MOE
filename
=
f
"Quant_
{
quant_type
.
name
}
_
{
hidden_size
}
.gguf"
sample_file
=
Path
(
sample_dir
)
/
filename
return
GGUFReader
(
sample_file
).
tensors
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float32
]
# Hidden_size for testing, must match the sample file in HF repo,
# Hidden_size for testing, must match the sample file in HF repo,
# we have `hidden_size = 256, 1024` for test in HF repo currently.
# we have `hidden_size = 256, 1024` for test in HF repo currently.
HIDDEN_SIZES
=
[
256
,
1024
]
HIDDEN_SIZES
=
[
256
,
1024
]
...
@@ -56,7 +69,7 @@ QUANT_TYPES = [
...
@@ -56,7 +69,7 @@ QUANT_TYPES = [
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
]
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
QUANT_TYPES
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
QUANT_TYPES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_dequantize
(
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
def
test_dequantize
(
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
...
@@ -126,7 +139,64 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
...
@@ -126,7 +139,64 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
ref_output
=
x
@
weight
.
T
ref_output
=
x
@
weight
.
T
qweight
=
torch
.
tensor
(
tensor
.
data
,
device
=
"cuda"
)
qweight
=
torch
.
tensor
(
tensor
.
data
,
device
=
"cuda"
)
output
=
ops
.
ggml_mul_mat_a8
(
qweight
,
x
,
quant_type
,
output
=
ops
.
ggml_mul_mat_a8
(
qweight
,
x
,
quant_type
,
qweight
.
shape
[
0
])
qweight
.
shape
[
0
]).
to
(
dtype
)
atols
=
{
torch
.
half
:
1
,
torch
.
bfloat16
:
1.5
,
torch
.
float
:
1.2
}
# test matrix has inputs centered around 0 and lower precision from
# bfloat16 tends to accumulate and can greatly inflate rtol
# since outputs are also very close to 0
rtols
=
{
torch
.
half
:
1e-1
,
torch
.
bfloat16
:
1e4
,
torch
.
float
:
2e1
}
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atols
[
dtype
],
rtol
=
rtols
[
dtype
])
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1
,
rtol
=
1e-1
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"top_k"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
[
# k-quants
GGMLQuantizationType
.
Q2_K
,
GGMLQuantizationType
.
Q3_K
,
GGMLQuantizationType
.
Q4_K
,
GGMLQuantizationType
.
Q5_K
,
GGMLQuantizationType
.
Q6_K
,
# standard quants
GGMLQuantizationType
.
Q4_0
,
GGMLQuantizationType
.
Q5_0
,
GGMLQuantizationType
.
Q8_0
,
])
@
torch
.
inference_mode
()
def
test_moe
(
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
quant_type
:
GGMLQuantizationType
,
top_k
:
int
):
current_platform
.
seed_everything
(
0
)
H
,
E
=
1024
,
256
x
=
torch
.
rand
((
num_tokens
,
H
),
dtype
=
dtype
,
device
=
"cuda"
)
topk_weights
=
torch
.
rand
(
num_tokens
,
top_k
,
device
=
"cuda"
,
dtype
=
dtype
)
topk_ids
=
torch
.
randint
(
0
,
E
,
(
num_tokens
,
top_k
),
device
=
"cuda"
)
tensors
=
get_gguf_MoE_tensors
(
hidden_size
,
quant_type
)
w13
=
tensors
[
0
]
w2
=
tensors
[
1
]
w13_dequant
=
torch
.
tensor
(
dequantize
(
w13
.
data
,
quant_type
),
device
=
"cuda"
).
to
(
dtype
)
w2_dequant
=
torch
.
tensor
(
dequantize
(
w2
.
data
,
quant_type
),
device
=
"cuda"
).
to
(
dtype
)
act
=
SiluAndMul
()
output
=
_fused_moe_gguf
(
x
,
torch
.
tensor
(
w13
.
data
,
device
=
"cuda"
),
torch
.
tensor
(
w2
.
data
,
device
=
"cuda"
),
topk_weights
,
topk_ids
,
quant_type
,
quant_type
,
act
)
ref_output
=
fused_experts
(
x
,
w13_dequant
,
w2_dequant
,
topk_weights
,
topk_ids
).
reshape
(
output
.
shape
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
1
,
rtol
=
1e-1
)
tests/kernels/test_machete_mm.py
View file @
469e903b
...
@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`.
...
@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`.
import
math
import
math
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -45,7 +45,7 @@ MNK_SHAPES = [
...
@@ -45,7 +45,7 @@ MNK_SHAPES = [
(
1024
,
8192
,
4096
),
(
1024
,
8192
,
4096
),
]
]
GROUP_SIZES_TO_TEST
:
L
ist
[
Optional
[
int
]]
=
[
128
,
-
1
]
GROUP_SIZES_TO_TEST
:
l
ist
[
Optional
[
int
]]
=
[
128
,
-
1
]
@
dataclass
@
dataclass
...
@@ -75,7 +75,7 @@ class Tensors:
...
@@ -75,7 +75,7 @@ class Tensors:
# Ch Scales Type, Tok Scales Type)
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
# None "Output Type" means the output type is the same as the act type
TestTypeTuple
=
T
uple
[
L
ist
[
torch
.
dtype
],
ScalarType
,
Optional
[
torch
.
dtype
],
TestTypeTuple
=
t
uple
[
l
ist
[
torch
.
dtype
],
ScalarType
,
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
bool
]
Optional
[
torch
.
dtype
],
bool
]
TEST_TYPES
=
[
TEST_TYPES
=
[
# GPTQ style
# GPTQ style
...
@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
...
@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return
zps
if
zps
is
None
else
-
1
*
s
*
(
zps
.
to
(
s
.
dtype
))
return
zps
if
zps
is
None
else
-
1
*
s
*
(
zps
.
to
(
s
.
dtype
))
def
group_size_valid
(
shape
:
T
uple
[
int
,
int
,
int
],
def
group_size_valid
(
shape
:
t
uple
[
int
,
int
,
int
],
group_size
:
Optional
[
int
])
->
bool
:
group_size
:
Optional
[
int
])
->
bool
:
return
group_size
is
None
or
group_size
==
-
1
or
group_size
%
shape
[
2
]
==
0
return
group_size
is
None
or
group_size
==
-
1
or
group_size
%
shape
[
2
]
==
0
...
@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype,
...
@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype,
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
def
create_test_tensors
(
shape
:
T
uple
[
int
,
int
,
int
],
def
create_test_tensors
(
shape
:
t
uple
[
int
,
int
,
int
],
types
:
TypeConfig
,
types
:
TypeConfig
,
group_size
:
Optional
[
int
],
group_size
:
Optional
[
int
],
subset_stride_factor
:
Optional
[
int
]
=
None
)
->
Tensors
:
subset_stride_factor
:
Optional
[
int
]
=
None
)
->
Tensors
:
...
@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig,
...
@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig,
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
def
test_machete_all_schedules
(
shape
,
types
:
TypeConfig
):
def
test_machete_all_schedules
(
shape
,
types
:
TypeConfig
):
group_sizes
:
L
ist
[
Optional
[
int
]]
=
[]
group_sizes
:
l
ist
[
Optional
[
int
]]
=
[]
if
types
.
group_scale_type
is
None
:
if
types
.
group_scale_type
is
None
:
group_sizes
=
[
None
]
group_sizes
=
[
None
]
else
:
else
:
...
@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
...
@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
ids
=
lambda
x
:
"x"
.
join
(
str
(
v
)
for
v
in
x
))
ids
=
lambda
x
:
"x"
.
join
(
str
(
v
)
for
v
in
x
))
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
def
test_machete_heuristic
(
shape
,
types
:
TypeConfig
):
def
test_machete_heuristic
(
shape
,
types
:
TypeConfig
):
group_sizes
:
L
ist
[
Optional
[
int
]]
=
[]
group_sizes
:
l
ist
[
Optional
[
int
]]
=
[]
if
types
.
group_scale_type
is
None
:
if
types
.
group_scale_type
is
None
:
group_sizes
=
[
None
]
group_sizes
=
[
None
]
else
:
else
:
...
...
tests/kernels/test_mamba_mixer2.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
unittest
import
unittest
from
typing
import
Tuple
import
pytest
import
pytest
import
torch
import
torch
...
@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables
...
@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables
def
test_mixer2_gated_norm_multi_gpu
(
def
test_mixer2_gated_norm_multi_gpu
(
batch_size
:
int
,
batch_size
:
int
,
seq_len
:
int
,
seq_len
:
int
,
hidden_size_n_groups
:
T
uple
[
int
,
int
],
hidden_size_n_groups
:
t
uple
[
int
,
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
str
=
'cuda'
,
device
:
str
=
'cuda'
,
):
):
...
...
tests/kernels/test_mamba_ssm_ssd.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
Tuple
import
pytest
import
pytest
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
...
@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
# given a tuple of lengths for each example in the batch
# given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc
# 4 examples from second eg, etc
def
get_continuous_batch
(
example_lens
:
T
uple
[
int
,
...]):
def
get_continuous_batch
(
example_lens
:
t
uple
[
int
,
...]):
indices
=
[]
indices
=
[]
for
i
,
x
in
enumerate
(
example_lens
):
for
i
,
x
in
enumerate
(
example_lens
):
...
@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
...
@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# hold state during the cutting process so we know if an
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
# example has been exhausted and needs to cycle
last_taken
:
D
ict
=
{}
# map: eg -> pointer to last taken sample
last_taken
:
d
ict
=
{}
# map: eg -> pointer to last taken sample
exhausted
:
D
ict
=
{}
# map: eg -> boolean indicating example is exhausted
exhausted
:
d
ict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
states
=
None
for
Y_min
,
cu_seqlens
,
sed_idx
,
(
A
,
dt
,
X
,
B
,
for
Y_min
,
cu_seqlens
,
sed_idx
,
(
A
,
dt
,
X
,
B
,
...
...
tests/kernels/test_moe.py
View file @
469e903b
...
@@ -3,8 +3,11 @@
...
@@ -3,8 +3,11 @@
Run `pytest tests/kernels/test_moe.py`.
Run `pytest tests/kernels/test_moe.py`.
"""
"""
import
pytest
import
pytest
import
torch
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
functional
as
F
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
...
@@ -26,6 +29,7 @@ from vllm.platforms import current_platform
...
@@ -26,6 +29,7 @@ from vllm.platforms import current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
NUM_EXPERTS
=
[
8
,
64
]
NUM_EXPERTS
=
[
8
,
64
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
2
,
6
]
TOP_KS
=
[
2
,
6
]
...
@@ -34,24 +38,64 @@ TOP_KS = [2, 6]
...
@@ -34,24 +38,64 @@ TOP_KS = [2, 6]
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
511
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
def
test_fused_moe
(
def
test_fused_moe
(
m
:
int
,
m
:
int
,
n
:
int
,
n
:
int
,
k
:
int
,
k
:
int
,
e
:
int
,
e
:
int
,
topk
:
int
,
topk
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
padding
:
bool
,
):
):
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
if
ep_size
>
1
:
local_e
=
e
//
ep_size
e_ids
=
torch
.
randint
(
0
,
e
,
(
local_e
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
=
torch
.
full
((
e
,
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
[
e_ids
]
=
torch
.
arange
(
local_e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
w1
=
w1
[
e_ids
]
w2
=
w2
[
e_ids
]
else
:
e_map
=
None
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
,
e_map
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
# Pad the weight if moe padding is enabled
if
padding
:
w1
=
F
.
pad
(
w1
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
w2
=
F
.
pad
(
w2
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
]
torch
.
cuda
.
empty_cache
()
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
global_num_experts
=
e
,
expert_map
=
e_map
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
iterative_output
=
iterative_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch
.
testing
.
assert_close
(
iterative_output
,
torch
.
testing
.
assert_close
(
iterative_output
,
torch_output
,
torch_output
,
atol
=
2e-2
,
atol
=
2e-2
,
...
@@ -63,13 +107,14 @@ def test_fused_moe(
...
@@ -63,13 +107,14 @@ def test_fused_moe(
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"has_zp"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"weight_bits"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"weight_bits"
,
[
4
,
8
])
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
def
test_fused_moe_wn16
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
has_zp
:
bool
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
group_size
:
int
,
weight_bits
:
int
):
has_zp
:
bool
,
weight_bits
:
int
):
print
(
m
,
n
,
k
,
e
,
topk
,
dtype
,
group_size
,
has_zp
,
weight_bits
)
print
(
m
,
n
,
k
,
e
,
topk
,
dtype
,
group_size
,
has_zp
,
weight_bits
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
...
@@ -130,6 +175,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -130,6 +175,25 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
if
has_zp
:
if
has_zp
:
w_qzeros
[
expert_id
]
=
qzeros
w_qzeros
[
expert_id
]
=
qzeros
if
ep_size
>
1
:
local_e
=
e
//
ep_size
e_ids
=
torch
.
randint
(
0
,
e
,
(
local_e
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
=
torch
.
full
((
e
,
),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
e_map
[
e_ids
]
=
torch
.
arange
(
local_e
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
w1_ref
=
w1_ref
[
e_ids
]
w2_ref
=
w2_ref
[
e_ids
]
w1_qweight
=
w1_qweight
[
e_ids
]
w2_qweight
=
w2_qweight
[
e_ids
]
w1_scales
=
w1_scales
[
e_ids
]
w2_scales
=
w2_scales
[
e_ids
]
w1_qzeros
=
w1_qzeros
[
e_ids
]
w2_qzeros
=
w2_qzeros
[
e_ids
]
else
:
e_map
=
None
triton_output
=
fused_moe
(
a
,
triton_output
=
fused_moe
(
a
,
w1_qweight
,
w1_qweight
,
w2_qweight
,
w2_qweight
,
...
@@ -138,19 +202,22 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
...
@@ -138,19 +202,22 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize
=
False
,
renormalize
=
False
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
weight_bits
==
8
,
use_int8_w8a16
=
weight_bits
==
8
,
global_num_experts
=
e
,
expert_map
=
e_map
,
w1_scale
=
w1_scales
,
w1_scale
=
w1_scales
,
w2_scale
=
w2_scales
,
w2_scale
=
w2_scales
,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w1_zp
=
w1_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
w2_zp
=
w2_qzeros
if
has_zp
else
None
,
block_shape
=
[
0
,
group_size
])
block_shape
=
[
0
,
group_size
])
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
)
torch_output
=
torch_moe
(
a
,
w1_ref
,
w2_ref
,
score
,
topk
,
e_map
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
):
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
,
padding
:
bool
):
"""Make sure our Mixtral MoE implementation agrees with the one from
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""
huggingface."""
...
@@ -164,6 +231,7 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -164,6 +231,7 @@ def test_mixtral_moe(dtype: torch.dtype):
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
params_dtype
=
dtype
,
params_dtype
=
dtype
,
tp_size
=
1
,
tp_size
=
1
,
dp_size
=
1
,
).
cuda
()
).
cuda
()
# Load the weights
# Load the weights
...
@@ -179,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -179,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
# vLLM uses 1D query [num_tokens, hidden_dim]
# vLLM uses 1D query [num_tokens, hidden_dim]
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
vllm_inputs
=
hf_inputs
.
flatten
(
0
,
1
)
# Pad the weight if moe padding is enabled
if
padding
:
vllm_moe
.
experts
.
w13_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
# Run forward passes for both MoE blocks
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
...
...
tests/kernels/test_nvfp4_scaled_mm.py
0 → 100644
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
reason
=
"Nvfp4 Requires compute capability of 10 or above."
,
allow_module_level
=
True
)
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
# m, n, k
SHAPES
=
[(
128
,
128
,
64
),
(
128
,
128
,
128
),
(
256
,
128
,
64
),
(
128
,
256
,
128
)]
PAD_SHAPES
=
[(
150
,
128
,
64
),
(
128
,
128
,
96
)]
SHAPES
.
extend
(
PAD_SHAPES
)
SEEDS
=
[
42
]
CUDA_DEVICES
=
[
'cuda:0'
]
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1fn
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
kE2M1ToFloatArray
=
[
0.
,
0.5
,
1.
,
1.5
,
2.
,
3.
,
4.
,
6.
,
]
def
e2m1_to_fp32
(
int4_value
):
signBit
=
(
int4_value
&
0x8
)
int4_absValue
=
int4_value
&
0x7
float_result
=
kE2M1ToFloatArray
[
int4_absValue
]
if
(
signBit
):
float_result
=
-
float_result
return
float_result
def
break_fp4_bytes
(
a
,
dtype
):
assert
(
a
.
dtype
==
torch
.
uint8
)
m
,
n
=
a
.
shape
a
=
a
.
flatten
()
# Get upper 4 bits
highHalfByte
=
(
a
&
0xF0
)
>>
4
# Get lower 4 bits
lowHalfByte
=
a
&
0x0F
fH
=
torch
.
tensor
([
e2m1_to_fp32
(
x
)
for
x
in
highHalfByte
]).
to
(
a
.
device
)
fL
=
torch
.
tensor
([
e2m1_to_fp32
(
x
)
for
x
in
lowHalfByte
]).
to
(
a
.
device
)
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
out
=
torch
.
stack
((
fL
,
fH
),
dim
=-
1
).
reshape
(
m
,
n
*
2
)
return
out
def
convert_swizzled_to_linear
(
a_sf_swizzled
:
torch
.
Tensor
,
m
,
k
,
block_size
):
sf_m
,
sf_k
=
a_sf_swizzled
.
shape
m_tiles
=
(
m
+
128
-
1
)
//
128
f
=
block_size
*
4
k_tiles
=
(
k
+
f
-
1
)
//
f
tmp
=
torch
.
reshape
(
a_sf_swizzled
,
(
1
,
m_tiles
,
k_tiles
,
32
,
4
,
4
))
tmp
=
torch
.
permute
(
tmp
,
(
0
,
1
,
4
,
3
,
2
,
5
))
out
=
tmp
.
reshape
(
m_tiles
*
128
,
k_tiles
*
f
//
block_size
)
return
out
[
0
:
m
,
0
:
k
]
def
dequantize_to_dtype
(
tensor_fp4
,
tensor_sf
,
global_scale
,
dtype
,
device
,
block_size
=
16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert
tensor_fp4
.
dtype
==
torch
.
uint8
m
,
packed_k
=
tensor_fp4
.
shape
k
=
packed_k
*
2
tensor_f32
=
break_fp4_bytes
(
tensor_fp4
,
dtype
)
tensor_f32
=
tensor_f32
.
reshape
(
m
,
k
//
block_size
,
block_size
)
tensor_sf
=
tensor_sf
.
view
(
torch
.
float8_e4m3fn
)
tensor_sf
=
convert_swizzled_to_linear
(
tensor_sf
,
m
,
k
,
block_size
)
tensor_sf_dtype
=
tensor_sf
.
to
(
torch
.
float32
)
/
global_scale
# scale the tensor
out
=
(
tensor_f32
*
tensor_sf_dtype
.
unsqueeze
(
-
1
)).
reshape
(
m
,
k
)
return
out
def
get_ref_results
(
a_fp4
,
b_fp4
,
a_sf
,
b_sf
,
a_global_scale
,
b_global_scale
,
m
,
n
,
dtype
,
block_size
,
device
):
_
,
m_k
=
a_fp4
.
shape
_
,
n_k
=
b_fp4
.
shape
assert
(
m_k
==
n_k
)
a_in_dtype
=
dequantize_to_dtype
(
a_fp4
,
a_sf
,
a_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
b_in_dtype
=
dequantize_to_dtype
(
b_fp4
,
b_sf
,
b_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
return
torch
.
matmul
(
a_in_dtype
,
b_in_dtype
.
t
())
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"shape"
,
SHAPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_nvfp4_gemm
(
dtype
:
torch
.
dtype
,
shape
:
tuple
[
int
,
int
,
int
],
seed
:
int
,
device
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
m
,
n
,
packed_k
=
shape
k
=
packed_k
*
2
block_size
=
16
a_dtype
=
torch
.
randn
((
m
,
k
),
dtype
=
dtype
,
device
=
device
)
b_dtype
=
torch
.
randn
((
n
,
k
),
dtype
=
dtype
,
device
=
device
)
a_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
a_dtype
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
b_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
b_dtype
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
alpha
=
1.
/
(
a_global_scale
*
b_global_scale
)
a_fp4
,
a_scale_interleaved
=
ops
.
scaled_fp4_quant
(
a_dtype
,
a_global_scale
)
b_fp4
,
b_scale_interleaved
=
ops
.
scaled_fp4_quant
(
b_dtype
,
b_global_scale
)
expected_out
=
get_ref_results
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
a_global_scale
,
b_global_scale
,
m
,
n
,
dtype
,
block_size
,
device
)
out
=
ops
.
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
dtype
)
torch
.
testing
.
assert_close
(
out
,
expected_out
.
to
(
dtype
=
dtype
),
atol
=
1e-1
,
rtol
=
1e-1
)
tests/kernels/test_pos_encoding.py
View file @
469e903b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
itertools
import
accumulate
,
product
from
itertools
import
accumulate
,
product
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rotary_dim
=
head_size
scaling_factors
:
L
ist
[
int
]
=
[
1
,
2
,
4
]
scaling_factors
:
l
ist
[
int
]
=
[
1
,
2
,
4
]
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)
"factor"
:
tuple
(
scaling_factors
)
...
@@ -234,7 +234,7 @@ def test_rope_module_cache():
...
@@ -234,7 +234,7 @@ def test_rope_module_cache():
})
})
settings
=
(
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
settings
=
(
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
ROPE_SCALINGS
,
DTYPES
)
ROPE_SCALINGS
,
DTYPES
)
rope_setting_id_map
:
D
ict
[
str
,
int
]
=
{}
rope_setting_id_map
:
d
ict
[
str
,
int
]
=
{}
for
setting
in
product
(
*
settings
):
for
setting
in
product
(
*
settings
):
head_size
,
rotary_dim
,
max_position
,
base
,
\
head_size
,
rotary_dim
,
max_position
,
base
,
\
is_neox_stype
,
rope_scaling
,
dtype
=
setting
is_neox_stype
,
rope_scaling
,
dtype
=
setting
...
...
Prev
1
…
16
17
18
19
20
21
22
23
24
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment