Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e150cf11
Commit
e150cf11
authored
Dec 05, 2024
by
zhuwenwen
Browse files
added support for kernels tests with torch 2.3
parent
a3d96521
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
753 additions
and
415 deletions
+753
-415
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+31
-14
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+132
-63
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+3
-5
tests/kernels/test_awq_triton.py
tests/kernels/test_awq_triton.py
+6
-1
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+254
-148
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+41
-23
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+35
-31
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+75
-42
tests/kernels/test_layernorm.py
tests/kernels/test_layernorm.py
+21
-11
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+43
-21
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+64
-28
tests/kernels/test_rotary_embedding.py
tests/kernels/test_rotary_embedding.py
+6
-1
tests/kernels/test_utils.py
tests/kernels/test_utils.py
+8
-5
tests/kernels/untest_permute_cols.py
tests/kernels/untest_permute_cols.py
+14
-5
tests/kernels/utils.py
tests/kernels/utils.py
+20
-17
No files found.
tests/kernels/test_activation.py
View file @
e150cf11
...
@@ -3,13 +3,13 @@ from typing import Type
...
@@ -3,13 +3,13 @@ from typing import Type
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.activation
import
(
FastGELU
,
GeluAndMul
,
from
vllm.model_executor.layers.activation
import
(
FastGELU
,
GeluAndMul
,
NewGELU
,
QuickGELU
,
NewGELU
,
QuickGELU
,
SiluAndMul
)
SiluAndMul
)
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
2048
]
# Arbitrary values for testing
NUM_TOKENS
=
[
7
,
83
,
2048
]
# Arbitrary values for testing
...
@@ -49,6 +49,11 @@ def test_act_and_mul(
...
@@ -49,6 +49,11 @@ def test_act_and_mul(
fn
=
torch
.
ops
.
_C
.
gelu_tanh_and_mul
fn
=
torch
.
ops
.
_C
.
gelu_tanh_and_mul
out
=
layer
(
x
)
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# The SiLU and GELU implementations are equivalent to the native PyTorch
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
# implementations, so we can do exact comparison.
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
...
@@ -57,6 +62,8 @@ def test_act_and_mul(
...
@@ -57,6 +62,8 @@ def test_act_and_mul(
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
opcheck
(
fn
,
(
out
,
x
))
opcheck
(
fn
,
(
out
,
x
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
[(
FastGELU
,
torch
.
ops
.
_C
.
gelu_fast
),
@
pytest
.
mark
.
parametrize
(
"activation"
,
[(
FastGELU
,
torch
.
ops
.
_C
.
gelu_fast
),
...
@@ -83,6 +90,14 @@ def test_activation(
...
@@ -83,6 +90,14 @@ def test_activation(
fn
=
activation
[
1
]
fn
=
activation
[
1
]
out
=
layer
(
x
)
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
ref_out
,
atol
=
get_default_atol
(
out
),
rtol
=
get_default_rtol
(
out
))
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
torch
.
testing
.
assert_close
(
out
,
ref_out
,
ref_out
,
atol
=
get_default_atol
(
out
),
atol
=
get_default_atol
(
out
),
...
@@ -90,3 +105,5 @@ def test_activation(
...
@@ -90,3 +105,5 @@ def test_activation(
out
=
torch
.
empty_like
(
x
)
out
=
torch
.
empty_like
(
x
)
opcheck
(
fn
,
(
out
,
x
))
opcheck
(
fn
,
(
out
,
x
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_attention.py
View file @
e150cf11
...
@@ -4,11 +4,11 @@ from typing import List, Optional, Tuple
...
@@ -4,11 +4,11 @@ from typing import List, Optional, Tuple
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
,
seed_everything
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
,
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.utils
import
torch_version
if
not
is_hip
():
if
not
is_hip
():
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
@@ -186,6 +186,25 @@ def test_paged_attention(
...
@@ -186,6 +186,25 @@ def test_paged_attention(
# Call the paged attention kernel.
# Call the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
if
version
==
"v1"
:
if
version
==
"v1"
:
if
torch_version
.
startswith
(
"2.3"
):
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
ops
.
paged_attention_v1
(
ops
.
paged_attention_v1
(
output
,
output
,
query
,
query
,
...
@@ -209,6 +228,8 @@ def test_paged_attention(
...
@@ -209,6 +228,8 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
,
None
,
0
),
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
,
None
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
elif
version
in
(
"v2"
,
"rocm"
):
elif
version
in
(
"v2"
,
"rocm"
):
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
...
@@ -224,6 +245,28 @@ def test_paged_attention(
...
@@ -224,6 +245,28 @@ def test_paged_attention(
)
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
version
==
"v2"
:
if
version
==
"v2"
:
if
torch_version
.
startswith
(
"2.3"
):
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
ops
.
paged_attention_v2
(
ops
.
paged_attention_v2
(
output
,
output
,
exp_sums
,
exp_sums
,
...
@@ -251,8 +294,32 @@ def test_paged_attention(
...
@@ -251,8 +294,32 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
,
None
,
0
),
kv_cache_dtype
,
k_scale
,
v_scale
,
0
,
0
,
0
,
64
,
0
,
None
,
0
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
else
:
else
:
if
torch_version
.
startswith
(
"2.3"
):
ops
.
paged_attention_rocm
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
ops
.
paged_attention_rocm
(
ops
.
paged_attention_rocm
(
output
,
output
,
exp_sums
,
exp_sums
,
...
@@ -280,6 +347,8 @@ def test_paged_attention(
...
@@ -280,6 +347,8 @@ def test_paged_attention(
kv_cache_dtype
,
k_scale
,
v_scale
),
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
cond
=
(
head_size
==
HEAD_SIZES
[
0
]
and
block_size
==
BLOCK_SIZES
[
0
]))
and
block_size
==
BLOCK_SIZES
[
0
]))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
else
:
else
:
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
raise
AssertionError
(
f
"Unknown version:
{
version
}
"
)
...
...
tests/kernels/test_attention_selector.py
View file @
e150cf11
...
@@ -6,14 +6,12 @@ import torch
...
@@ -6,14 +6,12 @@ import torch
from
tests.kernels.utils
import
override_backend_env_variable
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
which_attn_to_use
from
vllm.attention.selector
import
which_attn_to_use
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.utils
import
is_hip
# @pytest.mark.parametrize(
# "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
# @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"ROCM_FLASH"
])
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
]
if
not
is_hip
()
else
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
"""Test that the attention selector can be set via environment variable.
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
Note that we do not test FlashAttn because it is the default backend.
...
...
tests/kernels/test_awq_triton.py
View file @
e150cf11
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
vllm.model_executor.layers.quantization.awq_triton
import
(
from
vllm.model_executor.layers.quantization.awq_triton
import
(
AWQ_TRITON_SUPPORTED_GROUP_SIZES
,
awq_dequantize_triton
,
awq_gemm_triton
)
AWQ_TRITON_SUPPORTED_GROUP_SIZES
,
awq_dequantize_triton
,
awq_gemm_triton
)
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
device
=
"cuda"
device
=
"cuda"
...
@@ -64,6 +65,8 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
...
@@ -64,6 +65,8 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
# qweights - [R , C // 8], int32
# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
# zeros - [R // G, C // 8], int32
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need triton3.0."
)
@
pytest
.
mark
.
parametrize
(
"qweight_rows"
,
[
3584
,
18944
,
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"qweight_rows"
,
[
3584
,
18944
,
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"qweight_cols"
,
[
448
,
576
,
4736
,
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"qweight_cols"
,
[
448
,
576
,
4736
,
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
AWQ_TRITON_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
AWQ_TRITON_SUPPORTED_GROUP_SIZES
)
...
@@ -111,6 +114,8 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
...
@@ -111,6 +114,8 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
# qweight - [K, M // 8]
# qweight - [K, M // 8]
# qzeros - [K // G, M // 8]
# qzeros - [K // G, M // 8]
# scales - [K // G, M]
# scales - [K // G, M]
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need triton3.0."
)
@
pytest
.
mark
.
parametrize
(
"N"
,
[
1
,
2
,
4
,
8
,
14
,
17
,
23
,
32
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
1
,
2
,
4
,
8
,
14
,
17
,
23
,
32
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"M"
,
[
16
,
24
,
32
])
@
pytest
.
mark
.
parametrize
(
"M"
,
[
16
,
24
,
32
])
...
...
tests/kernels/test_cache.py
View file @
e150cf11
...
@@ -4,9 +4,11 @@ from typing import List, Tuple
...
@@ -4,9 +4,11 @@ from typing import List, Tuple
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
vllm.utils
import
is_hip
from
.utils
import
torch_version
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -88,6 +90,23 @@ def test_copy_blocks(
...
@@ -88,6 +90,23 @@ def test_copy_blocks(
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
device
).
view
(
-
1
,
2
)
device
=
device
).
view
(
-
1
,
2
)
if
torch_version
.
startswith
(
"2.3"
):
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
for
cloned_key_cache
in
cloned_key_caches
:
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
cloned_value_cache
in
cloned_value_caches
:
cloned_value_cache
[
dst
].
copy_
(
cloned_value_cache
[
src
])
# Compare the results.
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
(
key_caches
,
value_caches
,
block_mapping_tensor
),
(
key_caches
,
value_caches
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
...
@@ -107,6 +126,8 @@ def test_copy_blocks(
...
@@ -107,6 +126,8 @@ def test_copy_blocks(
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
cloned_value_caches
):
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -163,6 +184,45 @@ def test_reshape_and_cache(
...
@@ -163,6 +184,45 @@ def test_reshape_and_cache(
# Using default kv_scale
# Using default kv_scale
k_scale
=
v_scale
=
1.0
k_scale
=
v_scale
=
1.0
if
torch_version
.
startswith
(
"2.3"
):
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
allclose
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
allclose
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
@@ -201,6 +261,8 @@ def test_reshape_and_cache(
...
@@ -201,6 +261,8 @@ def test_reshape_and_cache(
else
:
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -272,6 +334,30 @@ def test_reshape_and_cache_flash(
...
@@ -272,6 +334,30 @@ def test_reshape_and_cache_flash(
# Using default kv_scale
# Using default kv_scale
k_scale
=
v_scale
=
1.0
k_scale
=
v_scale
=
1.0
if
torch_version
.
startswith
(
"2.3"
):
# Clone the KV caches.
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Call the reshape_and_cache kernel.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
'floor'
)
block_indicies
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
@@ -309,6 +395,8 @@ def test_reshape_and_cache_flash(
...
@@ -309,6 +395,8 @@ def test_reshape_and_cache_flash(
else
:
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
@@ -371,6 +459,20 @@ def test_swap_blocks(
...
@@ -371,6 +459,20 @@ def test_swap_blocks(
src_key_caches_clone
=
src_key_caches
[
0
].
clone
()
src_key_caches_clone
=
src_key_caches
[
0
].
clone
()
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
if
torch_version
.
startswith
(
"2.3"
):
# Call the swap_blocks kernel.
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
assert
torch
.
allclose
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
assert
torch
.
allclose
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the swap_blocks kernel.
# Call the swap_blocks kernel.
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
...
@@ -390,37 +492,41 @@ def test_swap_blocks(
...
@@ -390,37 +492,41 @@ def test_swap_blocks(
dist_key_caches
[
0
][
dst
].
cpu
())
dist_key_caches
[
0
][
dst
].
cpu
())
torch
.
testing
.
assert_close
(
src_value_caches_clone
[
src
].
cpu
(),
torch
.
testing
.
assert_close
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
dist_value_caches
[
0
][
dst
].
cpu
())
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"FP8 is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_fp8_e4m3_conversion
(
num_heads
:
int
,
head_size
:
int
,
block_size
:
int
,
num_blocks
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
)
->
None
:
seed_everything
(
seed
)
low
=
-
224.0
high
=
224.0
shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
cache
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
device
)
cache
.
uniform_
(
low
,
high
)
cache_fp8
=
torch
.
empty_like
(
cache
,
dtype
=
torch
.
uint8
)
ops
.
convert_fp8
(
cache_fp8
,
cache
)
converted_cache
=
torch
.
empty_like
(
cache
)
ops
.
convert_fp8
(
converted_cache
,
cache_fp8
)
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
torch
.
testing
.
assert_close
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
# num_heads: int,
# head_size: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# seed_everything(seed)
# low = -224.0
# high = 224.0
# shape = (num_blocks, num_heads, head_size, block_size)
# cache = torch.empty(shape, dtype=dtype, device=device)
# cache.uniform_(low, high)
# cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
# ops.convert_fp8(cache_fp8, cache)
# converted_cache = torch.empty_like(cache)
# ops.convert_fp8(converted_cache, cache_fp8)
# torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
tests/kernels/test_cutlass.py
View file @
e150cf11
...
@@ -7,16 +7,16 @@ from typing import Optional, Type
...
@@ -7,16 +7,16 @@ from typing import Optional, Type
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
#from vllm.platforms import current_platform
from
vllm.platforms
import
current_platform
from
.utils
import
torch_version
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
0
}
"
#
for i in range(1 if torch.cuda.device_count() == 1 else 2)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
#
capability = current_platform.get_device_capability()
capability
=
current_platform
.
get_device_capability
()
capability
=
90
#
capability[0] * 10 + capability[1]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
def
to_fp8
(
tensor
:
torch
.
Tensor
):
def
to_fp8
(
tensor
:
torch
.
Tensor
):
...
@@ -75,10 +75,16 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -75,10 +75,16 @@ def cutlass_fp8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1e-1
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
def
cutlass_int8_gemm_helper
(
m
:
int
,
def
cutlass_int8_gemm_helper
(
m
:
int
,
...
@@ -116,10 +122,15 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -116,10 +122,15 @@ def cutlass_int8_gemm_helper(m: int,
# print("out:",out[0:5][0:5])
# print("out:",out[0:5][0:5])
# print("baseline:",baseline[0:5][0:5])
# print("baseline:",baseline[0:5][0:5])
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
#
opcheck(torch.ops._C.cutlass_scaled_mm,
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
#
(out, a, b, scale_a, scale_b, bias))
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
# @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
# @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
...
@@ -350,18 +361,25 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
...
@@ -350,18 +361,25 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
# # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
# # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
# rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
# rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
# atol = 1e-3
# atol = 1e-3
# torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
# if torch_version.startswith("2.3"):
# torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
# assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
# assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
# if azp_per_token:
# elif torch_version.startswith("2.4"):
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# from tests.kernels.utils import opcheck
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
# torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
# func_bias))
# torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
# else:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# if azp_per_token:
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# func_bias))
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
# func_bias))
# else:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
# func_bias))
# else:
# print(f"PyTorch version {torch_version} is not specifically handled.")
# Test working with a subset of A and B
# Test working with a subset of A and B
...
...
tests/kernels/test_flash_attn.py
View file @
e150cf11
...
@@ -8,8 +8,8 @@ if is_hip():
...
@@ -8,8 +8,8 @@ if is_hip():
import
flash_attn
import
flash_attn
else
:
else
:
import
vllm.attention.backends.flash_attn
# noqa: F401
import
vllm.attention.backends.flash_attn
# noqa: F401
from
tests.kernels.utils
import
opcheck
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
HEAD_SIZES
=
[
128
,
256
]
...
@@ -132,6 +132,8 @@ if not is_hip():
...
@@ -132,6 +132,8 @@ if not is_hip():
else
:
else
:
test_utils
=
[
"test_faketensor"
]
test_utils
=
[
"test_faketensor"
]
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
args
=
tuple
(),
args
=
tuple
(),
kwargs
=
dict
(
kwargs
=
dict
(
...
@@ -253,6 +255,8 @@ def test_varlen_with_paged_kv(
...
@@ -253,6 +255,8 @@ def test_varlen_with_paged_kv(
test_utils
=
[
"test_faketensor"
]
test_utils
=
[
"test_faketensor"
]
if
not
is_hip
():
if
not
is_hip
():
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
args
=
tuple
(),
kwargs
=
dict
(
kwargs
=
dict
(
...
...
tests/kernels/test_int8_quant.py
View file @
e150cf11
...
@@ -2,10 +2,10 @@ import pytest
...
@@ -2,10 +2,10 @@ import pytest
import
torch
import
torch
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
scaled_int8_quant
from
vllm._custom_ops
import
scaled_int8_quant
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
...
@@ -15,7 +15,11 @@ SEEDS = [0]
...
@@ -15,7 +15,11 @@ SEEDS = [0]
SCALE
=
[
0.1
,
0.5
,
0.8
,
1.2
,
2.1
]
SCALE
=
[
0.1
,
0.5
,
0.8
,
1.2
,
2.1
]
def
opcheck_int8_quant_static
(
output
,
input
,
scale
,
azp
=
None
):
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
opcheck_int8_quant_static
(
output
,
input
,
scale
,
azp
=
None
):
if
azp
is
None
:
if
azp
is
None
:
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
,
None
))
(
output
,
input
,
scale
,
None
))
...
@@ -24,7 +28,7 @@ def opcheck_int8_quant_static(output, input, scale, azp=None):
...
@@ -24,7 +28,7 @@ def opcheck_int8_quant_static(output, input, scale, azp=None):
(
output
,
input
,
scale
,
azp
))
(
output
,
input
,
scale
,
azp
))
def
opcheck_int8_quant_dynamic
(
output
,
input
,
symmetric
=
True
):
def
opcheck_int8_quant_dynamic
(
output
,
input
,
symmetric
=
True
):
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
...
@@ -56,11 +60,18 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
...
@@ -56,11 +60,18 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel
# kernel
ops_out
,
ops_scales
,
_
=
scaled_int8_quant
(
x
)
ops_out
,
ops_scales
,
_
=
scaled_int8_quant
(
x
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
ops_scales
,
ref_scales
)
torch
.
allclose
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
# big atol to account for rounding errors
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
reason
=
"Currently, there is not supported on ROCm."
)
...
@@ -97,6 +108,11 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
...
@@ -97,6 +108,11 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
if
(
not
torch
.
allclose
(
scales_out
,
scales
)):
if
(
not
torch
.
allclose
(
scales_out
,
scales
)):
print
(
torch
.
argmax
(
torch
.
abs
(
scales_out
-
scales
)))
print
(
torch
.
argmax
(
torch
.
abs
(
scales_out
-
scales
)))
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
scales_out
,
scales
)
torch
.
allclose
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
torch
.
allclose
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
scales_out
,
scales
)
torch
.
testing
.
assert_close
(
scales_out
,
scales
)
# big atol to account for rounding errors
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
...
@@ -104,6 +120,8 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
...
@@ -104,6 +120,8 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
torch
.
testing
.
assert_close
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
,
False
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
,
False
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -125,10 +143,15 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
...
@@ -125,10 +143,15 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits
.
max
).
to
(
torch
.
int8
)
int8_traits
.
max
).
to
(
torch
.
int8
)
out2
,
_
,
_
=
scaled_int8_quant
(
x
,
scale_arg
)
out2
,
_
,
_
=
scaled_int8_quant
(
x
,
scale_arg
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
# big atol to account for rounding errors
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -155,10 +178,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
...
@@ -155,10 +178,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_arg
,
azp_arg
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_arg
,
azp_arg
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
# big atol to account for rounding errors
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
,
azp_arg
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
,
azp_arg
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_max"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"is_max"
,
[
True
,
False
])
...
@@ -190,4 +218,9 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
...
@@ -190,4 +218,9 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
out
=
torch
.
empty_like
(
expected
)
out
=
torch
.
empty_like
(
expected
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out
,
x
,
scale
,
azp
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out
,
x
,
scale
,
azp
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_layernorm.py
View file @
e150cf11
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
...
@@ -47,6 +47,14 @@ def test_rms_norm(
...
@@ -47,6 +47,14 @@ def test_rms_norm(
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
# Therefore, we use a larger tolerance.
if
torch_version
.
startswith
(
"2.3"
):
if
add_residual
:
torch
.
allclose
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
allclose
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
if
add_residual
:
if
add_residual
:
torch
.
testing
.
assert_close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
...
@@ -59,3 +67,5 @@ def test_rms_norm(
...
@@ -59,3 +67,5 @@ def test_rms_norm(
else
:
else
:
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_moe.py
View file @
e150cf11
...
@@ -9,7 +9,6 @@ import torch
...
@@ -9,7 +9,6 @@ import torch
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
...
@@ -22,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
...
@@ -22,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
from
vllm.utils
import
is_hip
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
...
@@ -76,7 +77,12 @@ def test_fused_moe(
...
@@ -76,7 +77,12 @@ def test_fused_moe(
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
...
@@ -120,11 +126,18 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -120,11 +126,18 @@ def test_mixtral_moe(dtype: torch.dtype):
torch
.
float16
:
1e-3
,
torch
.
float16
:
1e-3
,
torch
.
bfloat16
:
1e-2
,
torch
.
bfloat16
:
1e-2
,
}
}
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
hf_states
.
flatten
(
0
,
1
),
torch
.
testing
.
assert_close
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
atol
=
mixtral_moe_tol
[
dtype
])
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
def
stack_and_dev
(
tensors
:
List
[
torch
.
Tensor
]):
def
stack_and_dev
(
tensors
:
List
[
torch
.
Tensor
]):
...
@@ -137,6 +150,8 @@ def compute_max_diff(output, output_ref):
...
@@ -137,6 +150,8 @@ def compute_max_diff(output, output_ref):
torch
.
abs
(
output_ref
))
torch
.
abs
(
output_ref
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
,
512
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
,
512
])
...
@@ -256,6 +271,8 @@ def test_fused_marlin_moe(
...
@@ -256,6 +271,8 @@ def test_fused_marlin_moe(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
a
.
device
)
device
=
a
.
device
)
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_moe_C
.
topk_softmax
,
(
opcheck
(
torch
.
ops
.
_moe_C
.
topk_softmax
,
(
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
...
@@ -274,12 +291,16 @@ def test_fused_marlin_moe(
...
@@ -274,12 +291,16 @@ def test_fused_marlin_moe(
device
=
"cuda"
,
device
=
"cuda"
,
requires_grad
=
False
)
requires_grad
=
False
)
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales1
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
scales1
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
"don't run it in automated tests."
)
"don't run it in automated tests."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
...
@@ -373,7 +394,8 @@ def test_moe_align_block_size_opcheck():
...
@@ -373,7 +394,8 @@ def test_moe_align_block_size_opcheck():
num_tokens_post_pad
=
torch
.
empty
((
1
),
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
device
=
topk_ids
.
device
)
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_C
.
moe_align_block_size
,
opcheck
(
torch
.
ops
.
_C
.
moe_align_block_size
,
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
))
num_tokens_post_pad
))
tests/kernels/test_pos_encoding.py
View file @
e150cf11
...
@@ -8,6 +8,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -8,6 +8,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.utils
import
torch_version
IS_NEOX_STYLE
=
[
True
,
False
]
IS_NEOX_STYLE
=
[
True
,
False
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -18,7 +19,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing
...
@@ -18,7 +19,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS
=
[
11
,
8192
]
# Arbitrary values for testing
SEQ_LENS
=
[
11
,
8192
]
# Arbitrary values for testing
SEEDS
=
[
0
]
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
1
)
]
]
...
@@ -67,6 +68,16 @@ def test_rotary_embedding(
...
@@ -67,6 +68,16 @@ def test_rotary_embedding(
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
# Compare the results.
# Compare the results.
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
out_query
,
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
...
@@ -75,6 +86,8 @@ def test_rotary_embedding(
...
@@ -75,6 +86,8 @@ def test_rotary_embedding(
ref_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -126,6 +139,16 @@ def test_batched_rotary_embedding(
...
@@ -126,6 +139,16 @@ def test_batched_rotary_embedding(
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
))
device
=
device
))
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
# Compare the results.
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
ref_query
,
...
@@ -135,6 +158,8 @@ def test_batched_rotary_embedding(
...
@@ -135,6 +158,8 @@ def test_batched_rotary_embedding(
ref_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -195,6 +220,16 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -195,6 +220,16 @@ def test_batched_rotary_embedding_multi_lora(
query_offsets
)
query_offsets
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
query_offsets
.
flatten
())
query_offsets
.
flatten
())
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
allclose
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
# Compare the results.
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
ref_query
,
...
@@ -204,7 +239,8 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -204,7 +239,8 @@ def test_batched_rotary_embedding_multi_lora(
ref_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rope_module_cache
():
def
test_rope_module_cache
():
...
...
tests/kernels/test_rotary_embedding.py
View file @
e150cf11
...
@@ -7,8 +7,11 @@ from typing import Optional
...
@@ -7,8 +7,11 @@ from typing import Optional
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
.utils
import
torch_version
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
rotary_embedding_opcheck
(
rot
,
def
rotary_embedding_opcheck
(
rot
,
...
@@ -30,6 +33,8 @@ def rotary_embedding_opcheck(rot,
...
@@ -30,6 +33,8 @@ def rotary_embedding_opcheck(rot,
rot
.
is_neox_style
))
rot
.
is_neox_style
))
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need torch2.4."
)
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"max_position"
,
[
11
,
4096
,
32768
])
@
pytest
.
mark
.
parametrize
(
"max_position"
,
[
11
,
4096
,
32768
])
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
[
True
,
False
])
...
...
tests/kernels/test_utils.py
View file @
e150cf11
...
@@ -5,14 +5,17 @@ Tests for miscellaneous utilities
...
@@ -5,14 +5,17 @@ Tests for miscellaneous utilities
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.utils
import
torch_version
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
test_convert_fp8_opcheck
():
data
=
torch
.
randn
((
256
,
256
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
# def test_convert_fp8_opcheck():
result
=
torch
.
empty_like
(
data
,
dtype
=
torch
.
float8_e4m3fn
)
# data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
opcheck
(
torch
.
ops
.
_C_cache_ops
.
convert_fp8
,
(
result
,
data
,
1.0
,
"fp8"
))
# result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
# opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
...
...
tests/kernels/untest_permute_cols.py
View file @
e150cf11
...
@@ -3,13 +3,22 @@ import torch
...
@@ -3,13 +3,22 @@ import torch
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
permute_cols
from
vllm._custom_ops
import
permute_cols
from
.utils
import
torch_version
@
pytest
.
mark
.
parametrize
(
'shape'
,
[(
1
,
512
),
(
544
,
4096
),
(
67
,
8192
)])
@
pytest
.
mark
.
parametrize
(
'shape'
,
[(
1
,
512
),
(
544
,
4096
),
(
67
,
8192
)])
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
'dtype'
,
[
torch
.
bfloat16
,
torch
.
float16
])
def
test_permute_cols
(
shape
,
dtype
):
def
test_permute_cols
(
shape
,
dtype
):
if
torch_version
.
startswith
(
"2.3"
):
x
=
torch
.
randn
(
shape
,
dtype
=
dtype
).
cuda
()
perm
=
torch
.
randperm
(
x
.
shape
[
1
]).
to
(
torch
.
int
).
cuda
()
y
=
permute_cols
(
x
,
perm
)
torch
.
allclose
(
y
,
x
[:,
perm
])
elif
torch_version
.
startswith
(
"2.4"
):
x
=
torch
.
randn
(
shape
,
dtype
=
dtype
).
cuda
()
x
=
torch
.
randn
(
shape
,
dtype
=
dtype
).
cuda
()
perm
=
torch
.
randperm
(
x
.
shape
[
1
]).
to
(
torch
.
int
).
cuda
()
perm
=
torch
.
randperm
(
x
.
shape
[
1
]).
to
(
torch
.
int
).
cuda
()
opcheck
(
torch
.
ops
.
_C
.
permute_cols
,
(
x
,
perm
))
opcheck
(
torch
.
ops
.
_C
.
permute_cols
,
(
x
,
perm
))
y
=
permute_cols
(
x
,
perm
)
y
=
permute_cols
(
x
,
perm
)
torch
.
testing
.
assert_close
(
y
,
x
[:,
perm
])
torch
.
testing
.
assert_close
(
y
,
x
[:,
perm
])
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
\ No newline at end of file
tests/kernels/utils.py
View file @
e150cf11
...
@@ -30,6 +30,8 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
...
@@ -30,6 +30,8 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
"test_aot_dispatch_dynamic"
,
"test_aot_dispatch_dynamic"
,
)
)
torch_version
=
torch
.
__version__
class
QKVInputs
(
NamedTuple
):
class
QKVInputs
(
NamedTuple
):
'''
'''
...
@@ -974,9 +976,10 @@ def fp8_allclose(
...
@@ -974,9 +976,10 @@ def fp8_allclose(
equal_nan
=
equal_nan
)).
item
())
equal_nan
=
equal_nan
)).
item
())
# A special version of op check that has a restricted default set of test_utils
if
torch_version
.
startswith
(
"2.4"
):
# and a patched version of allclose that supports fp8 types.
# A special version of op check that has a restricted default set of test_utils
def
opcheck
(
op
:
Union
[
torch
.
_ops
.
OpOverload
,
torch
.
_ops
.
OpOverloadPacket
,
# and a patched version of allclose that supports fp8 types.
def
opcheck
(
op
:
Union
[
torch
.
_ops
.
OpOverload
,
torch
.
_ops
.
OpOverloadPacket
,
torch
.
_library
.
custom_ops
.
CustomOpDef
],
torch
.
_library
.
custom_ops
.
CustomOpDef
],
args
:
Tuple
[
Any
,
...],
args
:
Tuple
[
Any
,
...],
kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment