Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
217ee621
Commit
217ee621
authored
Dec 05, 2024
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev
parents
f0021a4d
3f78216a
Changes
68
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
782 additions
and
528 deletions
+782
-528
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+3
-5
tests/kernels/test_awq_triton.py
tests/kernels/test_awq_triton.py
+6
-1
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+254
-148
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+259
-230
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+2
-2
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+35
-31
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+81
-45
tests/kernels/test_layernorm.py
tests/kernels/test_layernorm.py
+21
-11
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+43
-21
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+64
-28
tests/kernels/test_rotary_embedding.py
tests/kernels/test_rotary_embedding.py
+6
-1
tests/kernels/test_utils.py
tests/kernels/test_utils.py
+8
-5
tests/kernels/untest_aqlm.py
tests/kernels/untest_aqlm.py
+0
-0
tests/kernels/untest_awq.py
tests/kernels/untest_awq.py
+0
-0
tests/kernels/untest_causal_conv1d.py
tests/kernels/untest_causal_conv1d.py
+0
-0
tests/kernels/untest_flashinfer.py
tests/kernels/untest_flashinfer.py
+0
-0
tests/kernels/untest_fp8_quant.py
tests/kernels/untest_fp8_quant.py
+0
-0
tests/kernels/untest_ggml.py
tests/kernels/untest_ggml.py
+0
-0
tests/kernels/untest_gguf.py
tests/kernels/untest_gguf.py
+0
-0
tests/kernels/untest_gptq.py
tests/kernels/untest_gptq.py
+0
-0
No files found.
tests/kernels/test_attention_selector.py
View file @
217ee621
...
...
@@ -6,14 +6,12 @@ import torch
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
which_attn_to_use
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.utils
import
is_hip
# @pytest.mark.parametrize(
# "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
# @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
]
if
not
is_hip
()
else
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
...
...
tests/kernels/test_awq_triton.py
View file @
217ee621
...
...
@@ -8,8 +8,9 @@ import torch
from
vllm.model_executor.layers.quantization.awq_triton
import
(
AWQ_TRITON_SUPPORTED_GROUP_SIZES
,
awq_dequantize_triton
,
awq_gemm_triton
)
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
device
=
"cuda"
device
=
"cuda"
def
reverse_awq_order
(
t
:
torch
.
Tensor
):
...
...
@@ -64,6 +65,8 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need triton3.0."
)
@
pytest
.
mark
.
parametrize
(
"qweight_rows"
,
[
3584
,
18944
,
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"qweight_cols"
,
[
448
,
576
,
4736
,
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
AWQ_TRITON_SUPPORTED_GROUP_SIZES
)
...
...
@@ -111,6 +114,8 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
# qweight - [K, M // 8]
# qzeros - [K // G, M // 8]
# scales - [K // G, M]
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need triton3.0."
)
@
pytest
.
mark
.
parametrize
(
"N"
,
[
1
,
2
,
4
,
8
,
14
,
17
,
23
,
32
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"M"
,
[
16
,
24
,
32
])
...
...
tests/kernels/test_cache.py
View file @
217ee621
...
...
@@ -4,9 +4,11 @@ from typing import List, Tuple
import
pytest
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
seed_everything
from
vllm.utils
import
is_hip
from
.utils
import
torch_version
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -87,26 +89,45 @@ def test_copy_blocks(
block_mapping_tensor
=
torch
.
tensor
(
block_mapping
,
dtype
=
torch
.
int64
,
device
=
device
).
view
(
-
1
,
2
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
(
key_caches
,
value_caches
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
# Run the reference implementation.
for
src
,
dst
in
block_mapping
:
for
cloned_key_cache
in
cloned_key_caches
:
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
cloned_value_cache
in
cloned_value_caches
:
cloned_value_cache
[
dst
].
copy_
(
cloned_value_cache
[
src
])
# Compare the results.
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
if
torch_version
.
startswith
(
"2.3"
):
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
for
cloned_key_cache
in
cloned_key_caches
:
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
cloned_value_cache
in
cloned_value_caches
:
cloned_value_cache
[
dst
].
copy_
(
cloned_value_cache
[
src
])
# Compare the results.
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
(
key_caches
,
value_caches
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
# Run the reference implementation.
for
src
,
dst
in
block_mapping
:
for
cloned_key_cache
in
cloned_key_caches
:
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
cloned_value_cache
in
cloned_value_caches
:
cloned_value_cache
[
dst
].
copy_
(
cloned_value_cache
[
src
])
# Compare the results.
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -162,46 +183,87 @@ def test_reshape_and_cache(
# Using default kv_scale
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
torch
.
testing
.
assert_close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
if
torch_version
.
startswith
(
"2.3"
):
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
allclose
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
allclose
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
torch
.
testing
.
assert_close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
...
@@ -272,43 +334,69 @@ def test_reshape_and_cache_flash(
# Using default kv_scale
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
if
torch_version
.
startswith
(
"2.3"
):
# Clone the KV caches.
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
if
kv_cache_dtype
==
"fp8"
:
torch
.
testing
.
assert_close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
'floor'
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
torch
.
testing
.
assert_close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
...
@@ -371,56 +459,74 @@ def test_swap_blocks(
src_key_caches_clone
=
src_key_caches
[
0
].
clone
()
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
# Call the swap_blocks kernel.
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
torch
.
testing
.
assert_close
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
torch
.
testing
.
assert_close
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
# num_heads: int,
# head_size: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# seed_everything(seed)
# low = -224.0
# high = 224.0
# shape = (num_blocks, num_heads, head_size, block_size)
# cache = torch.empty(shape, dtype=dtype, device=device)
# cache.uniform_(low, high)
# cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
# ops.convert_fp8(cache_fp8, cache)
# converted_cache = torch.empty_like(cache)
# ops.convert_fp8(converted_cache, cache_fp8)
# torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
if
torch_version
.
startswith
(
"2.3"
):
# Call the swap_blocks kernel.
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
assert
torch
.
allclose
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
assert
torch
.
allclose
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the swap_blocks kernel.
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
torch
.
testing
.
assert_close
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
torch
.
testing
.
assert_close
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"FP8 is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_fp8_e4m3_conversion
(
num_heads
:
int
,
head_size
:
int
,
block_size
:
int
,
num_blocks
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
)
->
None
:
seed_everything
(
seed
)
low
=
-
224.0
high
=
224.0
shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
cache
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
device
)
cache
.
uniform_
(
low
,
high
)
cache_fp8
=
torch
.
empty_like
(
cache
,
dtype
=
torch
.
uint8
)
ops
.
convert_fp8
(
cache_fp8
,
cache
)
converted_cache
=
torch
.
empty_like
(
cache
)
ops
.
convert_fp8
(
converted_cache
,
cache_fp8
)
torch
.
testing
.
assert_close
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
tests/kernels/test_cutlass.py
View file @
217ee621
This diff is collapsed.
Click to expand it.
tests/kernels/test_encoder_decoder_attn.py
View file @
217ee621
...
...
@@ -751,7 +751,7 @@ def test_encoder_only(
No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPU
s, therefore this test simply is skipped if is_hip().
hcu
s, therefore this test simply is skipped if is_hip().
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
...
...
@@ -860,7 +860,7 @@ def test_e2e_enc_dec_attn(
to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPU
s, therefore this test simply is skipped if is_hip().
hcu
s, therefore this test simply is skipped if is_hip().
Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
...
...
tests/kernels/test_flash_attn.py
View file @
217ee621
...
...
@@ -8,8 +8,8 @@ if is_hip():
import
flash_attn
else
:
import
vllm.attention.backends.flash_attn
# noqa: F401
from
tests.kernels.utils
import
opcheck
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
...
...
@@ -132,19 +132,21 @@ if not is_hip():
else
:
test_utils
=
[
"test_faketensor"
]
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
args
=
tuple
(),
kwargs
=
dict
(
decode_query
=
query
.
unsqueeze
(
1
),
key_cache
=
key_cache
,
value_cache
=
value_cache
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
args
=
tuple
(),
kwargs
=
dict
(
decode_query
=
query
.
unsqueeze
(
1
),
key_cache
=
key_cache
,
value_cache
=
value_cache
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
query
=
query
,
...
...
@@ -253,23 +255,25 @@ def test_varlen_with_paged_kv(
test_utils
=
[
"test_faketensor"
]
if
not
is_hip
():
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
kwargs
=
dict
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
kwargs
=
dict
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
query
=
query
,
...
...
tests/kernels/test_int8_quant.py
View file @
217ee621
...
...
@@ -2,9 +2,10 @@ import pytest
import
torch
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
scaled_int8_quant
from
vllm.utils
import
seed_everything
from
vllm.utils
import
is_hip
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
...
...
@@ -14,30 +15,35 @@ SEEDS = [0]
SCALE
=
[
0.1
,
0.5
,
0.8
,
1.2
,
2.1
]
def
opcheck_int8_quant_static
(
output
,
input
,
scale
,
azp
=
None
):
if
azp
is
None
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
,
None
))
else
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
,
azp
))
def
opcheck_int8_quant_dynamic
(
output
,
input
,
symmetric
=
True
):
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
if
symmetric
:
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
(
output
,
input
,
scale
,
None
))
else
:
azp
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
int32
)
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
(
output
,
input
,
scale
,
azp
))
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
opcheck_int8_quant_static
(
output
,
input
,
scale
,
azp
=
None
):
if
azp
is
None
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
,
None
))
else
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
,
azp
))
def
opcheck_int8_quant_dynamic
(
output
,
input
,
symmetric
=
True
):
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
if
symmetric
:
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
(
output
,
input
,
scale
,
None
))
else
:
azp
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
int32
)
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
(
output
,
input
,
scale
,
azp
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
...
@@ -54,13 +60,21 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel
ops_out
,
ops_scales
,
_
=
scaled_int8_quant
(
x
)
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
ops_scales
,
ref_scales
)
torch
.
allclose
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
...
@@ -94,13 +108,20 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
if
(
not
torch
.
allclose
(
scales_out
,
scales
)):
print
(
torch
.
argmax
(
torch
.
abs
(
scales_out
-
scales
)))
torch
.
testing
.
assert_close
(
scales_out
,
scales
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
# if AZP is off by 1, after rounding-to-even, the output may be off by 2
torch
.
testing
.
assert_close
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
,
False
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
scales_out
,
scales
)
torch
.
allclose
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
torch
.
allclose
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
scales_out
,
scales
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
# if AZP is off by 1, after rounding-to-even, the output may be off by 2
torch
.
testing
.
assert_close
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
,
False
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -122,10 +143,15 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits
.
max
).
to
(
torch
.
int8
)
out2
,
_
,
_
=
scaled_int8_quant
(
x
,
scale_arg
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -152,11 +178,16 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_arg
,
azp_arg
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
,
azp_arg
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
,
azp_arg
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_max"
,
[
True
,
False
])
@
torch
.
inference_mode
()
...
...
@@ -187,4 +218,9 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
out
=
torch
.
empty_like
(
expected
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out
,
x
,
scale
,
azp
)
torch
.
testing
.
assert_close
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_layernorm.py
View file @
217ee621
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
...
...
@@ -47,15 +47,25 @@ def test_rms_norm(
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
if
add_residual
:
torch
.
testing
.
assert_close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
if
torch_version
.
startswith
(
"2.3"
):
if
add_residual
:
torch
.
allclose
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
allclose
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
if
add_residual
:
torch
.
testing
.
assert_close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
if
residual
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
fused_add_rms_norm
,
(
x
,
residual
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
if
residual
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
fused_add_rms_norm
,
(
x
,
residual
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
else
:
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
else
:
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_moe.py
View file @
217ee621
...
...
@@ -9,7 +9,6 @@ import torch
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
...
...
@@ -22,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
from
vllm.utils
import
is_hip
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
...
...
@@ -76,7 +77,12 @@ def test_fused_moe(
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
...
...
@@ -120,11 +126,18 @@ def test_mixtral_moe(dtype: torch.dtype):
torch
.
float16
:
1e-3
,
torch
.
bfloat16
:
1e-2
,
}
torch
.
testing
.
assert_close
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
def
stack_and_dev
(
tensors
:
List
[
torch
.
Tensor
]):
...
...
@@ -137,6 +150,8 @@ def compute_max_diff(output, output_ref):
torch
.
abs
(
output_ref
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
,
512
])
...
...
@@ -256,12 +271,14 @@ def test_fused_marlin_moe(
dtype
=
torch
.
int32
,
device
=
a
.
device
)
opcheck
(
torch
.
ops
.
_moe_C
.
topk_softmax
,
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
score
.
float
(),
))
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_moe_C
.
topk_softmax
,
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
score
.
float
(),
))
block_size_m
=
4
...
...
@@ -274,12 +291,16 @@ def test_fused_marlin_moe(
device
=
"cuda"
,
requires_grad
=
False
)
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales1
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales1
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
"don't run it in automated tests."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
...
...
@@ -373,7 +394,8 @@ def test_moe_align_block_size_opcheck():
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
opcheck
(
torch
.
ops
.
_C
.
moe_align_block_size
,
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
))
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_C
.
moe_align_block_size
,
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
))
tests/kernels/test_pos_encoding.py
View file @
217ee621
...
...
@@ -8,6 +8,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.utils
import
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.utils
import
torch_version
IS_NEOX_STYLE
=
[
True
,
False
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -18,7 +19,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS
=
[
11
,
8192
]
# Arbitrary values for testing
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
1
)
]
...
...
@@ -67,14 +68,26 @@ def test_rotary_embedding(
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
...
@@ -126,15 +139,27 @@ def test_batched_rotary_embedding(
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
dtype
=
torch
.
long
,
device
=
device
))
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
...
@@ -195,16 +220,27 @@ def test_batched_rotary_embedding_multi_lora(
query_offsets
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
query_offsets
.
flatten
())
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
torch
.
inference_mode
()
def
test_rope_module_cache
():
...
...
tests/kernels/test_rotary_embedding.py
View file @
217ee621
...
...
@@ -7,8 +7,11 @@ from typing import Optional
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
.utils
import
torch_version
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
rotary_embedding_opcheck
(
rot
,
...
...
@@ -30,6 +33,8 @@ def rotary_embedding_opcheck(rot,
rot
.
is_neox_style
))
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need torch2.4."
)
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"max_position"
,
[
11
,
4096
,
32768
])
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
[
True
,
False
])
...
...
tests/kernels/test_utils.py
View file @
217ee621
...
...
@@ -5,14 +5,17 @@ Tests for miscellaneous utilities
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.platforms
import
current_platform
from
.utils
import
torch_version
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
test_convert_fp8_opcheck
():
data
=
torch
.
randn
((
256
,
256
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
result
=
torch
.
empty_like
(
data
,
dtype
=
torch
.
float8_e4m3fn
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
convert_fp8
,
(
result
,
data
,
1.0
,
"fp8"
))
# def test_convert_fp8_opcheck():
# data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
# result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
# opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
...
...
tests/kernels/test_aqlm.py
→
tests/kernels/
un
test_aqlm.py
View file @
217ee621
File moved
tests/kernels/test_awq.py
→
tests/kernels/
un
test_awq.py
View file @
217ee621
File moved
tests/kernels/test_causal_conv1d.py
→
tests/kernels/
un
test_causal_conv1d.py
View file @
217ee621
File moved
tests/kernels/test_flashinfer.py
→
tests/kernels/
un
test_flashinfer.py
View file @
217ee621
File moved
tests/kernels/test_fp8_quant.py
→
tests/kernels/
un
test_fp8_quant.py
View file @
217ee621
File moved
tests/kernels/test_ggml.py
→
tests/kernels/
un
test_ggml.py
View file @
217ee621
File moved
tests/kernels/test_gguf.py
→
tests/kernels/
un
test_gguf.py
View file @
217ee621
File moved
tests/kernels/test_gptq.py
→
tests/kernels/
un
test_gptq.py
View file @
217ee621
File moved
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment