Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e150cf11
Commit
e150cf11
authored
Dec 05, 2024
by
zhuwenwen
Browse files
added support for kernels tests with torch 2.3
parent
a3d96521
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
753 additions
and
415 deletions
+753
-415
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+31
-14
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+132
-63
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+3
-5
tests/kernels/test_awq_triton.py
tests/kernels/test_awq_triton.py
+6
-1
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+254
-148
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+41
-23
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+35
-31
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+75
-42
tests/kernels/test_layernorm.py
tests/kernels/test_layernorm.py
+21
-11
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+43
-21
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+64
-28
tests/kernels/test_rotary_embedding.py
tests/kernels/test_rotary_embedding.py
+6
-1
tests/kernels/test_utils.py
tests/kernels/test_utils.py
+8
-5
tests/kernels/untest_permute_cols.py
tests/kernels/untest_permute_cols.py
+14
-5
tests/kernels/utils.py
tests/kernels/utils.py
+20
-17
No files found.
tests/kernels/test_activation.py
View file @
e150cf11
...
...
@@ -3,13 +3,13 @@ from typing import Type
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.activation
import
(
FastGELU
,
GeluAndMul
,
NewGELU
,
QuickGELU
,
SiluAndMul
)
from
vllm.utils
import
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
2048
]
# Arbitrary values for testing
...
...
@@ -49,6 +49,11 @@ def test_act_and_mul(
fn
=
torch
.
ops
.
_C
.
gelu_tanh_and_mul
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
...
...
@@ -57,6 +62,8 @@ def test_act_and_mul(
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
opcheck
(
fn
,
(
out
,
x
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
[(
FastGELU
,
torch
.
ops
.
_C
.
gelu_fast
),
...
...
@@ -83,6 +90,14 @@ def test_activation(
fn
=
activation
[
1
]
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
get_default_atol
(
out
),
rtol
=
get_default_rtol
(
out
))
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
get_default_atol
(
out
),
...
...
@@ -90,3 +105,5 @@ def test_activation(
out
=
torch
.
empty_like
(
x
)
opcheck
(
fn
,
(
out
,
x
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_attention.py
View file @
e150cf11
...
...
@@ -4,11 +4,11 @@ from typing import List, Optional, Tuple
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
,
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.utils
import
torch_version
if
not
is_hip
():
from
xformers
import
ops
as
xops
...
...
@@ -186,6 +186,25 @@ def test_paged_attention(
# Call the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
if
version
==
"v1"
:
if
torch_version
.
startswith
(
"2.3"
):
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
ops
.
paged_attention_v1
(
output
,
query
,
...
...
@@ -209,6 +228,8 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
,
None
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
elif
version
in
(
"v2"
,
"rocm"
):
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
...
...
@@ -224,6 +245,28 @@ def test_paged_attention(
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
version
==
"v2"
:
if
torch_version
.
startswith
(
"2.3"
):
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
ops
.
paged_attention_v2
(
output
,
exp_sums
,
...
...
@@ -251,8 +294,32 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
,
None
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
else
:
if
torch_version
.
startswith
(
"2.3"
):
ops
.
paged_attention_rocm
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
ops
.
paged_attention_rocm
(
output
,
exp_sums
,
...
...
@@ -280,6 +347,8 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
else
:
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
...
...
tests/kernels/test_attention_selector.py
View file @
e150cf11
...
...
@@ -6,14 +6,12 @@ import torch
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
which_attn_to_use
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.utils
import
is_hip
# @pytest.mark.parametrize(
# "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
# @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
]
if
not
is_hip
()
else
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
...
...
tests/kernels/test_awq_triton.py
View file @
e150cf11
...
...
@@ -8,6 +8,7 @@ import torch
from
vllm.model_executor.layers.quantization.awq_triton
import
(
AWQ_TRITON_SUPPORTED_GROUP_SIZES
,
awq_dequantize_triton
,
awq_gemm_triton
)
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
device
=
"cuda"
...
...
@@ -64,6 +65,8 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need triton3.0."
)
@
pytest
.
mark
.
parametrize
(
"qweight_rows"
,
[
3584
,
18944
,
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"qweight_cols"
,
[
448
,
576
,
4736
,
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
AWQ_TRITON_SUPPORTED_GROUP_SIZES
)
...
...
@@ -111,6 +114,8 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
# qweight - [K, M // 8]
# qzeros - [K // G, M // 8]
# scales - [K // G, M]
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need triton3.0."
)
@
pytest
.
mark
.
parametrize
(
"N"
,
[
1
,
2
,
4
,
8
,
14
,
17
,
23
,
32
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"M"
,
[
16
,
24
,
32
])
...
...
tests/kernels/test_cache.py
View file @
e150cf11
...
...
@@ -4,9 +4,11 @@ from typing import List, Tuple
import
pytest
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
seed_everything
from
vllm.utils
import
is_hip
from
.utils
import
torch_version
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -88,6 +90,23 @@ def test_copy_blocks(
dtype
=
torch
.
int64
,
device
=
device
).
view
(
-
1
,
2
)
if
torch_version
.
startswith
(
"2.3"
):
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
for
cloned_key_cache
in
cloned_key_caches
:
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
cloned_value_cache
in
cloned_value_caches
:
cloned_value_cache
[
dst
].
copy_
(
cloned_value_cache
[
src
])
# Compare the results.
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
(
key_caches
,
value_caches
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
...
...
@@ -107,6 +126,8 @@ def test_copy_blocks(
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -163,6 +184,45 @@ def test_reshape_and_cache(
# Using default kv_scale
k_scale
=
v_scale
=
1.0
if
torch_version
.
startswith
(
"2.3"
):
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
allclose
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
allclose
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
...
@@ -201,6 +261,8 @@ def test_reshape_and_cache(
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -272,6 +334,30 @@ def test_reshape_and_cache_flash(
# Using default kv_scale
k_scale
=
v_scale
=
1.0
if
torch_version
.
startswith
(
"2.3"
):
# Clone the KV caches.
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
'floor'
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
...
@@ -309,6 +395,8 @@ def test_reshape_and_cache_flash(
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
...
@@ -371,6 +459,20 @@ def test_swap_blocks(
src_key_caches_clone
=
src_key_caches
[
0
].
clone
()
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
if
torch_version
.
startswith
(
"2.3"
):
# Call the swap_blocks kernel.
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
assert
torch
.
allclose
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
assert
torch
.
allclose
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the swap_blocks kernel.
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
...
...
@@ -390,37 +492,41 @@ def test_swap_blocks(
dist_key_caches
[
0
][
dst
].
cpu
())
torch
.
testing
.
assert_close
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"FP8 is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_fp8_e4m3_conversion
(
num_heads
:
int
,
head_size
:
int
,
block_size
:
int
,
num_blocks
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
)
->
None
:
seed_everything
(
seed
)
low
=
-
224.0
high
=
224.0
shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
cache
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
device
)
cache
.
uniform_
(
low
,
high
)
cache_fp8
=
torch
.
empty_like
(
cache
,
dtype
=
torch
.
uint8
)
ops
.
convert_fp8
(
cache_fp8
,
cache
)
converted_cache
=
torch
.
empty_like
(
cache
)
ops
.
convert_fp8
(
converted_cache
,
cache_fp8
)
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
# num_heads: int,
# head_size: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# seed_everything(seed)
# low = -224.0
# high = 224.0
# shape = (num_blocks, num_heads, head_size, block_size)
# cache = torch.empty(shape, dtype=dtype, device=device)
# cache.uniform_(low, high)
# cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
# ops.convert_fp8(cache_fp8, cache)
# converted_cache = torch.empty_like(cache)
# ops.convert_fp8(converted_cache, cache_fp8)
# torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
torch
.
testing
.
assert_close
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
tests/kernels/test_cutlass.py
View file @
e150cf11
...
...
@@ -7,16 +7,16 @@ from typing import Optional, Type
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
#from vllm.platforms import current_platform
from
vllm.platforms
import
current_platform
from
.utils
import
torch_version
CUDA_DEVICES
=
[
f
"cuda:
{
0
}
"
#
for i in range(1 if torch.cuda.device_count() == 1 else 2)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
#
capability = current_platform.get_device_capability()
capability
=
90
#
capability[0] * 10 + capability[1]
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
def
to_fp8
(
tensor
:
torch
.
Tensor
):
...
...
@@ -75,10 +75,16 @@ def cutlass_fp8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1e-1
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
def
cutlass_int8_gemm_helper
(
m
:
int
,
...
...
@@ -116,10 +122,15 @@ def cutlass_int8_gemm_helper(m: int,
# print("out:",out[0:5][0:5])
# print("baseline:",baseline[0:5][0:5])
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
#
opcheck(torch.ops._C.cutlass_scaled_mm,
#
(out, a, b, scale_a, scale_b, bias))
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
# @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
...
...
@@ -350,18 +361,25 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
# rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
# atol = 1e-3
# torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
# torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
# if azp_per_token:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
# func_bias))
# else:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
# func_bias))
# atol = 1e-3
# if torch_version.startswith("2.3"):
# assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
# assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
# elif torch_version.startswith("2.4"):
# from tests.kernels.utils import opcheck
# torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
# torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
# if azp_per_token:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
# func_bias))
# else:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
# func_bias))
# else:
# print(f"PyTorch version {torch_version} is not specifically handled.")
# Test working with a subset of A and B
...
...
tests/kernels/test_flash_attn.py
View file @
e150cf11
...
...
@@ -8,8 +8,8 @@ if is_hip():
import
flash_attn
else
:
import
vllm.attention.backends.flash_attn
# noqa: F401
from
tests.kernels.utils
import
opcheck
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
...
...
@@ -132,6 +132,8 @@ if not is_hip():
else
:
test_utils
=
[
"test_faketensor"
]
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
args
=
tuple
(),
kwargs
=
dict
(
...
...
@@ -253,6 +255,8 @@ def test_varlen_with_paged_kv(
test_utils
=
[
"test_faketensor"
]
if
not
is_hip
():
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
kwargs
=
dict
(
...
...
tests/kernels/test_int8_quant.py
View file @
e150cf11
...
...
@@ -2,10 +2,10 @@ import pytest
import
torch
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
scaled_int8_quant
from
vllm.utils
import
seed_everything
from
vllm.utils
import
is_hip
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
...
...
@@ -15,7 +15,11 @@ SEEDS = [0]
SCALE
=
[
0.1
,
0.5
,
0.8
,
1.2
,
2.1
]
def
opcheck_int8_quant_static
(
output
,
input
,
scale
,
azp
=
None
):
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
opcheck_int8_quant_static
(
output
,
input
,
scale
,
azp
=
None
):
if
azp
is
None
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
,
None
))
...
...
@@ -24,7 +28,7 @@ def opcheck_int8_quant_static(output, input, scale, azp=None):
(
output
,
input
,
scale
,
azp
))
def
opcheck_int8_quant_dynamic
(
output
,
input
,
symmetric
=
True
):
def
opcheck_int8_quant_dynamic
(
output
,
input
,
symmetric
=
True
):
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
...
...
@@ -56,11 +60,18 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel
ops_out
,
ops_scales
,
_
=
scaled_int8_quant
(
x
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
ops_scales
,
ref_scales
)
torch
.
allclose
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
...
...
@@ -97,6 +108,11 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
if
(
not
torch
.
allclose
(
scales_out
,
scales
)):
print
(
torch
.
argmax
(
torch
.
abs
(
scales_out
-
scales
)))
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
scales_out
,
scales
)
torch
.
allclose
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
torch
.
allclose
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
scales_out
,
scales
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
...
...
@@ -104,6 +120,8 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
torch
.
testing
.
assert_close
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
,
False
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -125,10 +143,15 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits
.
max
).
to
(
torch
.
int8
)
out2
,
_
,
_
=
scaled_int8_quant
(
x
,
scale_arg
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -155,10 +178,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_arg
,
azp_arg
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
,
azp_arg
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_max"
,
[
True
,
False
])
...
...
@@ -190,4 +218,9 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
out
=
torch
.
empty_like
(
expected
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out
,
x
,
scale
,
azp
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_layernorm.py
View file @
e150cf11
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
...
...
@@ -47,6 +47,14 @@ def test_rms_norm(
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
if
torch_version
.
startswith
(
"2.3"
):
if
add_residual
:
torch
.
allclose
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
allclose
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
if
add_residual
:
torch
.
testing
.
assert_close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
...
...
@@ -59,3 +67,5 @@ def test_rms_norm(
else
:
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_moe.py
View file @
e150cf11
...
...
@@ -9,7 +9,6 @@ import torch
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
...
...
@@ -22,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
from
vllm.utils
import
is_hip
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
...
...
@@ -76,7 +77,12 @@ def test_fused_moe(
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
...
...
@@ -120,11 +126,18 @@ def test_mixtral_moe(dtype: torch.dtype):
torch
.
float16
:
1e-3
,
torch
.
bfloat16
:
1e-2
,
}
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
def
stack_and_dev
(
tensors
:
List
[
torch
.
Tensor
]):
...
...
@@ -137,6 +150,8 @@ def compute_max_diff(output, output_ref):
torch
.
abs
(
output_ref
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
,
512
])
...
...
@@ -256,6 +271,8 @@ def test_fused_marlin_moe(
dtype
=
torch
.
int32
,
device
=
a
.
device
)
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_moe_C
.
topk_softmax
,
(
topk_weights
,
topk_ids
,
...
...
@@ -274,12 +291,16 @@ def test_fused_marlin_moe(
device
=
"cuda"
,
requires_grad
=
False
)
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales1
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
"don't run it in automated tests."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
...
...
@@ -373,7 +394,8 @@ def test_moe_align_block_size_opcheck():
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_C
.
moe_align_block_size
,
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
))
tests/kernels/test_pos_encoding.py
View file @
e150cf11
...
...
@@ -8,6 +8,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.utils
import
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.utils
import
torch_version
IS_NEOX_STYLE
=
[
True
,
False
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -18,7 +19,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS
=
[
11
,
8192
]
# Arbitrary values for testing
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
1
)
]
...
...
@@ -67,6 +68,16 @@ def test_rotary_embedding(
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
# Compare the results.
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
...
...
@@ -75,6 +86,8 @@ def test_rotary_embedding(
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
...
@@ -126,6 +139,16 @@ def test_batched_rotary_embedding(
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
dtype
=
torch
.
long
,
device
=
device
))
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
...
...
@@ -135,6 +158,8 @@ def test_batched_rotary_embedding(
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
...
@@ -195,6 +220,16 @@ def test_batched_rotary_embedding_multi_lora(
query_offsets
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
query_offsets
.
flatten
())
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
...
...
@@ -204,7 +239,8 @@ def test_batched_rotary_embedding_multi_lora(
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
torch
.
inference_mode
()
def
test_rope_module_cache
():
...
...
tests/kernels/test_rotary_embedding.py
View file @
e150cf11
...
...
@@ -7,8 +7,11 @@ from typing import Optional
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
.utils
import
torch_version
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
rotary_embedding_opcheck
(
rot
,
...
...
@@ -30,6 +33,8 @@ def rotary_embedding_opcheck(rot,
rot
.
is_neox_style
))
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need torch2.4."
)
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"max_position"
,
[
11
,
4096
,
32768
])
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
[
True
,
False
])
...
...
tests/kernels/test_utils.py
View file @
e150cf11
...
...
@@ -5,14 +5,17 @@ Tests for miscellaneous utilities
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.platforms
import
current_platform
from
.utils
import
torch_version
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
test_convert_fp8_opcheck
():
data
=
torch
.
randn
((
256
,
256
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
result
=
torch
.
empty_like
(
data
,
dtype
=
torch
.
float8_e4m3fn
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
convert_fp8
,
(
result
,
data
,
1.0
,
"fp8"
))
# def test_convert_fp8_opcheck():
# data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
# result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
# opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
...
...
tests/kernels/untest_permute_cols.py
View file @
e150cf11
...
...
@@ -3,13 +3,22 @@ import torch
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
permute_cols
from
.utils
import
torch_version
@
pytest
.
mark
.
parametrize
(
'shape'
,
[(
1
,
512
),
(
544
,
4096
),
(
67
,
8192
)])
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
torch
.
bfloat16
,
torch
.
float16
])
def
test_permute_cols
(
shape
,
dtype
):
if
torch_version
.
startswith
(
"2.3"
):
x
=
torch
.
randn
(
shape
,
dtype
=
dtype
).
cuda
()
perm
=
torch
.
randperm
(
x
.
shape
[
1
]).
to
(
torch
.
int
).
cuda
()
y
=
permute_cols
(
x
,
perm
)
torch
.
allclose
(
y
,
x
[:,
perm
])
elif
torch_version
.
startswith
(
"2.4"
):
x
=
torch
.
randn
(
shape
,
dtype
=
dtype
).
cuda
()
perm
=
torch
.
randperm
(
x
.
shape
[
1
]).
to
(
torch
.
int
).
cuda
()
opcheck
(
torch
.
ops
.
_C
.
permute_cols
,
(
x
,
perm
))
y
=
permute_cols
(
x
,
perm
)
torch
.
testing
.
assert_close
(
y
,
x
[:,
perm
])
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
\ No newline at end of file
tests/kernels/utils.py
View file @
e150cf11
...
...
@@ -30,6 +30,8 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
"test_aot_dispatch_dynamic"
,
)
torch_version
=
torch
.
__version__
class
QKVInputs
(
NamedTuple
):
'''
...
...
@@ -974,9 +976,10 @@ def fp8_allclose(
equal_nan
=
equal_nan
)).
item
())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def
opcheck
(
op
:
Union
[
torch
.
_ops
.
OpOverload
,
torch
.
_ops
.
OpOverloadPacket
,
if
torch_version
.
startswith
(
"2.4"
):
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def
opcheck
(
op
:
Union
[
torch
.
_ops
.
OpOverload
,
torch
.
_ops
.
OpOverloadPacket
,
torch
.
_library
.
custom_ops
.
CustomOpDef
],
args
:
Tuple
[
Any
,
...],
kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment