Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
217ee621
Commit
217ee621
authored
Dec 05, 2024
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev
parents
f0021a4d
3f78216a
Changes
68
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
782 additions
and
528 deletions
+782
-528
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+3
-5
tests/kernels/test_awq_triton.py
tests/kernels/test_awq_triton.py
+6
-1
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+254
-148
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+259
-230
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+2
-2
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+35
-31
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+81
-45
tests/kernels/test_layernorm.py
tests/kernels/test_layernorm.py
+21
-11
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+43
-21
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+64
-28
tests/kernels/test_rotary_embedding.py
tests/kernels/test_rotary_embedding.py
+6
-1
tests/kernels/test_utils.py
tests/kernels/test_utils.py
+8
-5
tests/kernels/untest_aqlm.py
tests/kernels/untest_aqlm.py
+0
-0
tests/kernels/untest_awq.py
tests/kernels/untest_awq.py
+0
-0
tests/kernels/untest_causal_conv1d.py
tests/kernels/untest_causal_conv1d.py
+0
-0
tests/kernels/untest_flashinfer.py
tests/kernels/untest_flashinfer.py
+0
-0
tests/kernels/untest_fp8_quant.py
tests/kernels/untest_fp8_quant.py
+0
-0
tests/kernels/untest_ggml.py
tests/kernels/untest_ggml.py
+0
-0
tests/kernels/untest_gguf.py
tests/kernels/untest_gguf.py
+0
-0
tests/kernels/untest_gptq.py
tests/kernels/untest_gptq.py
+0
-0
No files found.
tests/kernels/test_attention_selector.py
View file @
217ee621
...
@@ -6,14 +6,12 @@ import torch
...
@@ -6,14 +6,12 @@ import torch
from
tests.kernels.utils
import
override_backend_env_variable
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.attention.selector
import
which_attn_to_use
from
vllm.attention.selector
import
which_attn_to_use
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.utils
import
STR_FLASH_ATTN_VAL
,
STR_INVALID_VAL
from
vllm.utils
import
is_hip
# @pytest.mark.parametrize(
# "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
# @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
])
"name"
,
[
"TORCH_SDPA"
,
"ROCM_FLASH"
,
"XFORMERS"
,
"FLASHINFER"
,
"OPENVINO"
]
if
not
is_hip
()
else
[
"ROCM_FLASH"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"openvino"
,
"hip"
,
"cuda"
])
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
def
test_env
(
name
:
str
,
device
:
str
,
monkeypatch
):
"""Test that the attention selector can be set via environment variable.
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
Note that we do not test FlashAttn because it is the default backend.
...
...
tests/kernels/test_awq_triton.py
View file @
217ee621
...
@@ -8,8 +8,9 @@ import torch
...
@@ -8,8 +8,9 @@ import torch
from
vllm.model_executor.layers.quantization.awq_triton
import
(
from
vllm.model_executor.layers.quantization.awq_triton
import
(
AWQ_TRITON_SUPPORTED_GROUP_SIZES
,
awq_dequantize_triton
,
awq_gemm_triton
)
AWQ_TRITON_SUPPORTED_GROUP_SIZES
,
awq_dequantize_triton
,
awq_gemm_triton
)
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
device
=
"cuda"
device
=
"cuda"
def
reverse_awq_order
(
t
:
torch
.
Tensor
):
def
reverse_awq_order
(
t
:
torch
.
Tensor
):
...
@@ -64,6 +65,8 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
...
@@ -64,6 +65,8 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
# qweights - [R , C // 8], int32
# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
# zeros - [R // G, C // 8], int32
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need triton3.0."
)
@
pytest
.
mark
.
parametrize
(
"qweight_rows"
,
[
3584
,
18944
,
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"qweight_rows"
,
[
3584
,
18944
,
128
,
256
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"qweight_cols"
,
[
448
,
576
,
4736
,
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"qweight_cols"
,
[
448
,
576
,
4736
,
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
AWQ_TRITON_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
AWQ_TRITON_SUPPORTED_GROUP_SIZES
)
...
@@ -111,6 +114,8 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
...
@@ -111,6 +114,8 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
# qweight - [K, M // 8]
# qweight - [K, M // 8]
# qzeros - [K // G, M // 8]
# qzeros - [K // G, M // 8]
# scales - [K // G, M]
# scales - [K // G, M]
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need triton3.0."
)
@
pytest
.
mark
.
parametrize
(
"N"
,
[
1
,
2
,
4
,
8
,
14
,
17
,
23
,
32
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
1
,
2
,
4
,
8
,
14
,
17
,
23
,
32
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"M"
,
[
16
,
24
,
32
])
@
pytest
.
mark
.
parametrize
(
"M"
,
[
16
,
24
,
32
])
...
...
tests/kernels/test_cache.py
View file @
217ee621
...
@@ -4,9 +4,11 @@ from typing import List, Tuple
...
@@ -4,9 +4,11 @@ from typing import List, Tuple
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
vllm.utils
import
is_hip
from
.utils
import
torch_version
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -87,26 +89,45 @@ def test_copy_blocks(
...
@@ -87,26 +89,45 @@ def test_copy_blocks(
block_mapping_tensor
=
torch
.
tensor
(
block_mapping
,
block_mapping_tensor
=
torch
.
tensor
(
block_mapping
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
device
).
view
(
-
1
,
2
)
device
=
device
).
view
(
-
1
,
2
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
if
torch_version
.
startswith
(
"2.3"
):
(
key_caches
,
value_caches
,
block_mapping_tensor
),
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
for
src
,
dst
in
block_mapping
:
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
for
cloned_key_cache
in
cloned_key_caches
:
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
cloned_value_cache
in
cloned_value_caches
:
# Run the reference implementation.
cloned_value_cache
[
dst
].
copy_
(
cloned_value_cache
[
src
])
for
src
,
dst
in
block_mapping
:
for
cloned_key_cache
in
cloned_key_caches
:
# Compare the results.
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
for
cloned_value_cache
in
cloned_value_caches
:
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
cloned_value_cache
[
dst
].
copy_
(
cloned_value_cache
[
src
])
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
# Compare the results.
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
from
tests.kernels.utils
import
opcheck
cloned_value_caches
):
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
(
key_caches
,
value_caches
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
# Run the reference implementation.
for
src
,
dst
in
block_mapping
:
for
cloned_key_cache
in
cloned_key_caches
:
cloned_key_cache
[
dst
].
copy_
(
cloned_key_cache
[
src
])
for
cloned_value_cache
in
cloned_value_caches
:
cloned_value_cache
[
dst
].
copy_
(
cloned_value_cache
[
src
])
# Compare the results.
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -162,46 +183,87 @@ def test_reshape_and_cache(
...
@@ -162,46 +183,87 @@ def test_reshape_and_cache(
# Using default kv_scale
# Using default kv_scale
k_scale
=
v_scale
=
1.0
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
if
torch_version
.
startswith
(
"2.3"
):
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
k_scale
,
v_scale
),
kv_cache_dtype
,
k_scale
,
v_scale
)
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
if
kv_cache_dtype
==
"fp8"
:
kv_cache_dtype
,
k_scale
,
v_scale
)
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
if
kv_cache_dtype
==
"fp8"
:
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
# Run the reference implementation.
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
# Run the reference implementation.
block_indicies
=
block_indicies
.
cpu
().
tolist
()
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_offsets
=
slot_mapping
%
block_size
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_offsets
=
block_offsets
.
cpu
().
tolist
()
block_indicies_lst
=
block_indicies
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_offsets
=
slot_mapping
%
block_size
block_idx
=
block_indicies
[
i
]
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
block_offset
=
block_offsets
[
i
]
for
i
in
range
(
num_tokens
):
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
block_idx
=
block_indicies_lst
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
if
kv_cache_dtype
==
"fp8"
:
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
assert
torch
.
allclose
(
result_key_cache
,
cloned_key_cache
,
if
kv_cache_dtype
==
"fp8"
:
atol
=
0.001
,
torch
.
testing
.
assert_close
(
result_key_cache
,
rtol
=
0.1
)
cloned_key_cache
,
assert
torch
.
allclose
(
result_value_cache
,
atol
=
0.001
,
cloned_value_cache
,
rtol
=
0.1
)
atol
=
0.001
,
torch
.
testing
.
assert_close
(
result_value_cache
,
rtol
=
0.1
)
cloned_value_cache
,
else
:
atol
=
0.001
,
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
rtol
=
0.1
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
:,
:,
block_offset
,
:]
=
reshaped_key
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
torch
.
testing
.
assert_close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
@@ -272,43 +334,69 @@ def test_reshape_and_cache_flash(
...
@@ -272,43 +334,69 @@ def test_reshape_and_cache_flash(
# Using default kv_scale
# Using default kv_scale
k_scale
=
v_scale
=
1.0
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
if
torch_version
.
startswith
(
"2.3"
):
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
# Clone the KV caches.
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
cloned_key_cache
=
key_cache
.
clone
()
k_scale
,
v_scale
),
cloned_value_cache
=
value_cache
.
clone
()
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
# Call the reshape_and_cache kernel.
torch
.
testing
.
assert_close
(
result_key_cache
,
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
cloned_key_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
atol
=
0.001
,
rtol
=
0.1
)
# Run the reference implementation.
torch
.
testing
.
assert_close
(
result_value_cache
,
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
'floor'
)
cloned_value_cache
,
block_indicies
=
block_indicies
.
cpu
().
tolist
()
atol
=
0.001
,
block_offsets
=
slot_mapping
%
block_size
rtol
=
0.1
)
block_offsets
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies
[
i
]
block_offset
=
block_offsets
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
assert
torch
.
allclose
(
key_cache
,
cloned_key_cache
)
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies_lst
=
block_indicies
.
cpu
().
tolist
()
block_offsets
=
slot_mapping
%
block_size
block_offsets_lst
=
block_offsets
.
cpu
().
tolist
()
for
i
in
range
(
num_tokens
):
block_idx
=
block_indicies_lst
[
i
]
block_offset
=
block_offsets_lst
[
i
]
cloned_key_cache
[
block_idx
,
block_offset
,
:,
:]
=
key
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
torch
.
testing
.
assert_close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
else
:
else
:
torch
.
testing
.
assert_close
(
key_cache
,
cloned_key_cache
)
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
torch
.
testing
.
assert_close
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
@@ -371,56 +459,74 @@ def test_swap_blocks(
...
@@ -371,56 +459,74 @@ def test_swap_blocks(
src_key_caches_clone
=
src_key_caches
[
0
].
clone
()
src_key_caches_clone
=
src_key_caches
[
0
].
clone
()
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
# Call the swap_blocks kernel.
if
torch_version
.
startswith
(
"2.3"
):
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
# Call the swap_blocks kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
),
block_mapping_tensor
)
cond
=
do_opcheck
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
block_mapping_tensor
)
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
for
src
,
dst
in
block_mapping
:
assert
torch
.
allclose
(
src_key_caches_clone
[
src
].
cpu
(),
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
dist_key_caches
[
0
][
dst
].
cpu
())
block_mapping_tensor
)
assert
torch
.
allclose
(
src_value_caches_clone
[
src
].
cpu
(),
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
dist_value_caches
[
0
][
dst
].
cpu
())
block_mapping_tensor
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
for
src
,
dst
in
block_mapping
:
# Call the swap_blocks kernel.
torch
.
testing
.
assert_close
(
src_key_caches_clone
[
src
].
cpu
(),
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
dist_key_caches
[
0
][
dst
].
cpu
())
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
torch
.
testing
.
assert_close
(
src_value_caches_clone
[
src
].
cpu
(),
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
),
dist_value_caches
[
0
][
dst
].
cpu
())
cond
=
do_opcheck
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
),
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
cond
=
do_opcheck
)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
block_mapping_tensor
)
# @pytest.mark.parametrize("dtype", DTYPES)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
# @pytest.mark.parametrize("seed", SEEDS)
block_mapping_tensor
)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
for
src
,
dst
in
block_mapping
:
# def test_fp8_e4m3_conversion(
torch
.
testing
.
assert_close
(
src_key_caches_clone
[
src
].
cpu
(),
# num_heads: int,
dist_key_caches
[
0
][
dst
].
cpu
())
# head_size: int,
torch
.
testing
.
assert_close
(
src_value_caches_clone
[
src
].
cpu
(),
# block_size: int,
dist_value_caches
[
0
][
dst
].
cpu
())
# num_blocks: int,
else
:
# dtype: torch.dtype,
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
# seed: int,
# device: str,
# ) -> None:
@
pytest
.
mark
.
skipif
(
is_hip
(),
# seed_everything(seed)
reason
=
"FP8 is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
# low = -224.0
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
# high = 224.0
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
# shape = (num_blocks, num_heads, head_size, block_size)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
# cache = torch.empty(shape, dtype=dtype, device=device)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
# cache.uniform_(low, high)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
# cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
@
torch
.
inference_mode
()
# ops.convert_fp8(cache_fp8, cache)
def
test_fp8_e4m3_conversion
(
num_heads
:
int
,
# converted_cache = torch.empty_like(cache)
head_size
:
int
,
# ops.convert_fp8(converted_cache, cache_fp8)
block_size
:
int
,
num_blocks
:
int
,
# torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
)
->
None
:
seed_everything
(
seed
)
low
=
-
224.0
high
=
224.0
shape
=
(
num_blocks
,
num_heads
,
head_size
,
block_size
)
cache
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
device
)
cache
.
uniform_
(
low
,
high
)
cache_fp8
=
torch
.
empty_like
(
cache
,
dtype
=
torch
.
uint8
)
ops
.
convert_fp8
(
cache_fp8
,
cache
)
converted_cache
=
torch
.
empty_like
(
cache
)
ops
.
convert_fp8
(
converted_cache
,
cache_fp8
)
torch
.
testing
.
assert_close
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
tests/kernels/test_cutlass.py
View file @
217ee621
...
@@ -7,9 +7,9 @@ from typing import Optional, Type
...
@@ -7,9 +7,9 @@ from typing import Optional, Type
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.utils
import
torch_version
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
...
@@ -39,7 +39,7 @@ def baseline_scaled_mm(a: torch.Tensor,
...
@@ -39,7 +39,7 @@ def baseline_scaled_mm(a: torch.Tensor,
scale_b
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
(
scale_a
*
(
scale_b
*
(
torch
.
mm
(
output
=
(
scale_a
*
(
scale_b
.
T
*
(
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
=
output
+
bias
output
=
output
+
bias
...
@@ -75,10 +75,16 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -75,10 +75,16 @@ def cutlass_fp8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1e-1
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
def
cutlass_int8_gemm_helper
(
m
:
int
,
def
cutlass_int8_gemm_helper
(
m
:
int
,
...
@@ -99,7 +105,7 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -99,7 +105,7 @@ def cutlass_int8_gemm_helper(m: int,
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
scale_b
=
(
torch
.
randn
((
n_b_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
dtype
=
torch
.
float32
))
if
use_bias
:
if
use_bias
:
...
@@ -107,42 +113,53 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -107,42 +113,53 @@ def cutlass_int8_gemm_helper(m: int,
else
:
else
:
bias
=
None
bias
=
None
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
# print("a.shape:",a.shape)
# print("b.shape:",b.shape)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
# print("out:",out[0:5][0:5])
# print("baseline:",baseline[0:5][0:5])
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
100
,
33
])
#
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
4096
,
8192
,
16384
,
24576
,
256
,
1024
])
#
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
#
@pytest.mark.parametrize("k", [128, 496, 1024])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
#
@pytest.mark.parametrize("per_act_token", [True, False])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
#
@pytest.mark.parametrize("use_bias", [True, False])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason
=
"FP8 is not supported on this GPU type."
)
#
reason="FP8 is not supported on this GPU type.")
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
#
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch
:
bool
,
use_bias
:
bool
):
#
per_out_ch: bool, use_bias: bool):
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
#
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
8192
,
16384
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
8192
,
16384
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
per_out_ch
:
bool
,
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
b
float16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
float16
])
#
torch.
b
float16
,
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
],
out_dtype
:
Type
[
torch
.
dtype
],
...
@@ -156,50 +173,50 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
...
@@ -156,50 +173,50 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype
=
out_dtype
)
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
#
@pytest.mark.parametrize("per_act_token", [True, False])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
#
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
#
@pytest.mark.parametrize("use_bias", [True, False])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason
=
"FP8 is not supported on this GPU type."
)
#
reason="FP8 is not supported on this GPU type.")
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
#
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype
:
Type
[
torch
.
dtype
],
#
out_dtype: Type[torch.dtype],
use_bias
:
bool
):
#
use_bias: bool):
cutlass_fp8_gemm_helper
(
512
,
#
cutlass_fp8_gemm_helper(512,
512
,
#
512,
512
,
#
512,
per_act_token
,
#
per_act_token,
per_out_ch
,
#
per_out_ch,
use_bias
,
#
use_bias,
out_dtype
=
out_dtype
)
#
out_dtype=out_dtype)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
#
@pytest.mark.parametrize("per_act_token", [True, False])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
#
@pytest.mark.parametrize("use_bias", [True, False])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
#
@pytest.mark.parametrize("device", CUDA_DEVICES)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason
=
"FP8 is not supported on this GPU type."
)
#
reason="FP8 is not supported on this GPU type.")
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
#
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
use_bias
:
bool
,
device
:
str
):
#
use_bias: bool, device: str):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
#
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
torch
.
bfloat16
,
device
)
#
torch.bfloat16, device)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
#
@pytest.mark.parametrize("per_act_token", [True, False])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
#
@pytest.mark.parametrize("use_bias", [True, False])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
#
@pytest.mark.parametrize("device", CUDA_DEVICES)
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
#
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
use_bias
:
bool
,
device
:
str
):
#
use_bias: bool, device: str):
cutlass_int8_gemm_helper
(
512
,
#
cutlass_int8_gemm_helper(512,
512
,
#
512,
512
,
#
512,
per_act_token
,
#
per_act_token,
per_out_ch
,
#
per_out_ch,
use_bias
,
#
use_bias,
out_dtype
=
torch
.
bfloat16
,
#
out_dtype=torch.bfloat16,
device
=
device
)
#
device=device)
# For the following two tests:
# For the following two tests:
...
@@ -207,155 +224,162 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
...
@@ -207,155 +224,162 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
# kernel must handle any M thrown at it.
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
# @pytest.mark.parametrize("per_act_token", [True, False])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
# @pytest.mark.parametrize("per_out_ch", [True, False])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
# @pytest.mark.parametrize("use_bias", [True, False])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
# @pytest.mark.skipif(not current_platform.has_device_capability(89),
reason
=
"FP8 is not supported on this GPU type."
)
# reason="FP8 is not supported on this GPU type.")
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
# def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias
:
bool
):
# use_bias: bool):
for
nk
in
range
(
32
,
128
,
32
):
# for nk in range(32, 128, 32):
for
m
in
range
(
1
,
128
):
# for m in range(1, 128):
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
# cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
use_bias
)
# use_bias)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
# @pytest.mark.parametrize("per_act_token", [True, False])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
# @pytest.mark.parametrize("per_out_ch", [True, False])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
# @pytest.mark.parametrize("use_bias", [True, False])
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
# def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias
:
bool
):
# use_bias: bool):
for
nk
in
range
(
32
,
128
,
32
):
# for nk in range(32, 128, 32):
for
m
in
range
(
1
,
128
):
# for m in range(1, 128):
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
# cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
use_bias
)
# use_bias)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
# @pytest.mark.parametrize("m", [32, 64, 128])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
16
,
32
,
64
])
# @pytest.mark.parametrize("n", [16, 32, 64])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
64
,
128
,
256
])
# @pytest.mark.parametrize("k", [64, 128, 256])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@
pytest
.
mark
.
skip
# @pytest.mark.skip
def
test_cutlass_int8_azp_bias_fold
(
m
:
int
,
n
:
int
,
k
:
int
,
# def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
out_dtype
:
torch
.
dtype
):
# out_dtype: torch.dtype):
# Currently, the test is failing because folding azp into
# # Currently, the test is failing because folding azp into
# 16-bit bias loses too much precision
# # 16-bit bias loses too much precision
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
# scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b
=
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
# scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
aq_i8
=
rand_int8
((
m
,
k
))
# aq_i8 = rand_int8((m, k))
bq_i8
=
rand_int8
((
n
,
k
)).
t
()
# bq_i8 = rand_int8((n, k)).t()
aq_i32
=
aq_i8
.
to
(
dtype
=
torch
.
int32
)
# aq_i32 = aq_i8.to(dtype=torch.int32)
bq_i32
=
bq_i8
.
to
(
dtype
=
torch
.
int32
)
# bq_i32 = bq_i8.to(dtype=torch.int32)
aq_f32
=
aq_i8
.
to
(
dtype
=
torch
.
float32
)
# aq_f32 = aq_i8.to(dtype=torch.float32)
bq_f32
=
bq_i8
.
to
(
dtype
=
torch
.
float32
)
# bq_f32 = bq_i8.to(dtype=torch.float32)
b_dq
=
scale_b
*
bq_f32
# b_dq = scale_b * bq_f32
azp_a
=
torch
.
rand
((
1
,
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
10
+
1.5
# azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_aq_i8
=
(
azp_a
/
scale_a
).
to
(
dtype
=
torch
.
int8
)
# azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
# azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
a_dq
=
scale_a
*
(
aq_i32
+
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
# a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
torch
.
testing
.
assert_close
(
a_dq
,
scale_a
*
aq_f32
+
azp_a
)
# torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
baseline_dq
=
torch
.
mm
(
a_dq
,
b_dq
).
to
(
out_dtype
)
# baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
J
=
torch
.
ones
((
1
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
# J = torch.ones((1, k), device="cuda", dtype=torch.float32)
azp_bias
=
(
azp_a
*
scale_b
*
(
J
@
bq_f32
)).
to
(
out_dtype
)
# azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
assert
azp_bias
.
shape
==
(
1
,
n
)
# assert azp_bias.shape == (1, n)
assert
azp_bias
[
0
,
:].
shape
==
(
n
,
)
# assert azp_bias[0, :].shape == (n, )
baseline_q
=
(
scale_a
.
to
(
device
=
'cpu'
)
*
scale_b
.
to
(
device
=
'cpu'
)
*
(
# baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
(
aq_i32
+
azp_aq_i8
).
to
(
device
=
'cpu'
)
@
bq_i32
.
to
(
device
=
'cpu'
))).
to
(
# (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
dtype
=
out_dtype
,
device
=
'cuda'
)
# dtype=out_dtype, device='cuda')
out
=
ops
.
cutlass_scaled_mm
(
aq_i8
,
# out = ops.cutlass_scaled_mm(aq_i8,
bq_i8
,
# bq_i8,
scale_a
,
# scale_a,
scale_b
,
# scale_b,
out_dtype
=
out_dtype
,
# out_dtype=out_dtype,
bias
=
azp_bias
[
0
,
:])
# bias=azp_bias[0, :])
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
1e-2
,
atol
=
1e0
)
# torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
1e-2
,
atol
=
1e0
)
# torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
# @pytest.mark.parametrize("m", [32, 64, 128])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
16
,
32
,
64
])
# @pytest.mark.parametrize("n", [16, 32, 64])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
64
,
128
,
256
])
# @pytest.mark.parametrize("k", [64, 128, 256])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
# @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
# @pytest.mark.parametrize("use_bias", [True, False])
@
pytest
.
mark
.
parametrize
(
"azp_per_token"
,
[
True
,
False
])
# @pytest.mark.parametrize("azp_per_token", [True, False])
def
test_cutlass_int8_azp
(
m
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
,
# def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
use_bias
:
bool
,
azp_per_token
:
bool
):
# use_bias: bool, azp_per_token: bool):
m_azp
=
m
if
azp_per_token
else
1
# m_azp = m if azp_per_token else 1
scale_a
=
torch
.
randn
((
m_azp
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
# scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
scale_b
=
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
# scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
aq_i8
=
rand_int8
((
m
,
k
))
# aq_i8 = rand_int8((m, k))
aq_i32
=
aq_i8
.
to
(
dtype
=
torch
.
int32
)
# aq_i32 = aq_i8.to(dtype=torch.int32)
aq_f32
=
aq_i8
.
to
(
dtype
=
torch
.
float32
)
# aq_f32 = aq_i8.to(dtype=torch.float32)
bq_i8
=
rand_int8
((
n
,
k
)).
t
()
# bq_i8 = rand_int8((n, k)).t()
bq_i32
=
bq_i8
.
to
(
dtype
=
torch
.
int32
)
# bq_i32 = bq_i8.to(dtype=torch.int32)
bq_f32
=
bq_i8
.
to
(
dtype
=
torch
.
float32
)
# bq_f32 = bq_i8.to(dtype=torch.float32)
b_dq
=
scale_b
*
bq_f32
# b_dq = scale_b * bq_f32
azp_a
=
torch
.
rand
(
# azp_a = torch.rand(
(
m_azp
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
10
+
1.5
# (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_aq_i8
=
(
azp_a
/
scale_a
).
to
(
dtype
=
torch
.
int8
)
# azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
# azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
a_dq
=
scale_a
*
(
aq_i32
-
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
# a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
torch
.
testing
.
assert_close
(
a_dq
,
# torch.testing.assert_close(a_dq,
scale_a
*
aq_f32
-
azp_a
,
# scale_a * aq_f32 - azp_a,
rtol
=
1e-4
,
# rtol=1e-4,
atol
=
1e-3
)
# atol=1e-3)
if
use_bias
:
# if use_bias:
bias
=
torch
.
rand
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
*
10
+
2.5
# bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
else
:
# else:
bias
=
torch
.
zeros
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
# bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
baseline_dq
=
(
torch
.
mm
(
a_dq
,
b_dq
)
+
bias
).
to
(
out_dtype
)
# baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
# int32 mm not supported on CUDA
# # int32 mm not supported on CUDA
a_noazp_i32_cpu
=
(
aq_i32
-
azp_aq_i8
).
to
(
device
=
'cpu'
)
# a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
cq
=
(
a_noazp_i32_cpu
@
bq_i32
.
to
(
device
=
'cpu'
)).
to
(
device
=
'cuda'
)
# cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
baseline_q
=
(
scale_a
*
scale_b
*
cq
+
bias
).
to
(
dtype
=
out_dtype
)
# baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
# Hadamard is just the sum of the cols
# # Hadamard is just the sum of the cols
azp_adj_i32
=
bq_i32
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
# azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
azp_i32
=
azp_aq_i8
.
to
(
dtype
=
torch
.
int32
)
# azp_i32 = azp_aq_i8.to(dtype=torch.int32)
func_bias
=
bias
if
use_bias
else
None
# func_bias = bias if use_bias else None
if
azp_per_token
:
# if azp_per_token:
out
=
ops
.
cutlass_scaled_mm_azp
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
# out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
out_dtype
,
azp_adj_i32
,
azp_i32
,
# out_dtype, azp_adj_i32, azp_i32,
func_bias
)
# func_bias)
else
:
# else:
azp_with_adj_i32
=
azp_i32
*
azp_adj_i32
# azp_with_adj_i32 = azp_i32 * azp_adj_i32
out
=
ops
.
cutlass_scaled_mm_azp
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
# out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
out_dtype
,
azp_with_adj_i32
,
None
,
# out_dtype, azp_with_adj_i32, None,
func_bias
)
# func_bias)
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
# # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol
=
1e-2
if
out_dtype
==
torch
.
bfloat16
else
1e-3
# rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
atol
=
1e-3
# atol = 1e-3
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
rtol
,
atol
=
atol
)
# if torch_version.startswith("2.3"):
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
rtol
,
atol
=
atol
)
# assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
# assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
if
azp_per_token
:
# elif torch_version.startswith("2.4"):
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
# from tests.kernels.utils import opcheck
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_adj_i32
,
azp_i32
,
# torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
func_bias
))
# torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
else
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
# if azp_per_token:
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_with_adj_i32
,
None
,
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
func_bias
))
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
# func_bias))
# else:
# opcheck(torch.ops._C.cutlass_scaled_mm_azp,
# (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
# func_bias))
# else:
# print(f"PyTorch version {torch_version} is not specifically handled.")
# Test working with a subset of A and B
# Test working with a subset of A and B
...
@@ -367,7 +391,11 @@ def test_cutlass_subset():
...
@@ -367,7 +391,11 @@ def test_cutlass_subset():
whole_b
=
to_int8
(
torch
.
randn
((
big_n
,
big_k
),
device
=
"cuda"
).
t
()
*
5
)
whole_b
=
to_int8
(
torch
.
randn
((
big_n
,
big_k
),
device
=
"cuda"
).
t
()
*
5
)
a
=
whole_a
[
0
:
m
,
0
:
k
]
a
=
whole_a
[
0
:
m
,
0
:
k
]
b
=
whole_b
[
0
:
k
,
0
:
n
]
b
=
whole_b
[
0
:
k
,
0
:
n
]
#变成连续内存,矩阵子模块目前不支持计算,需要重新计算lda
a
=
a
.
contiguous
().
reshape
(
m
,
-
1
)
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
...
@@ -399,25 +427,26 @@ class CutlassLayer(torch.nn.Module):
...
@@ -399,25 +427,26 @@ class CutlassLayer(torch.nn.Module):
return
ops
.
cutlass_scaled_mm
(
a
,
self
.
b
,
self
.
scale_a
,
self
.
scale_b
,
return
ops
.
cutlass_scaled_mm
(
a
,
self
.
b
,
self
.
scale_a
,
self
.
scale_b
,
self
.
out_dtype
)
self
.
out_dtype
)
#目前只支持per-act-token+per-out-ch(fp16)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
def
test_cutlass_cuda_graph
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
def
test_cutlass_cuda_graph
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
m
,
n
,
k
=
512
,
512
,
512
m
,
n
,
k
=
512
,
512
,
512
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
"cuda"
))
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
"cuda"
))
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
"cuda"
).
t
())
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
"cuda"
).
t
())
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
m_a_scales
=
m
if
per_act_token
else
1
m_a_scales
=
m
if
per_act_token
else
1
n_b_scales
=
n
if
per_out_ch
else
1
n_b_scales
=
n
if
per_out_ch
else
1
scale_a
=
(
torch
.
randn
(
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
(
m_a_scales
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
(
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
(
n_b_scales
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
# Construct a trivial model with a single layer that calls a CUTLASS kernel
# Construct a trivial model with a single layer that calls a CUTLASS kernel
model
=
CutlassLayer
(
b
,
scale_a
,
scale_b
,
torch
.
b
float16
)
model
=
CutlassLayer
(
b
,
scale_a
,
scale_b
,
torch
.
float16
)
# Run the model with a cuda graph
# Run the model with a cuda graph
stream
=
torch
.
cuda
.
Stream
()
stream
=
torch
.
cuda
.
Stream
()
...
@@ -429,9 +458,9 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
...
@@ -429,9 +458,9 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
g
.
replay
()
g
.
replay
()
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
bfloat16
)
scale_b
.
T
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
float16
)
#print("baseline:",baseline)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
)
#print("out:",out)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
def
test_cutlass_support_opcheck
():
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
,
(
capability
,
))
tests/kernels/test_encoder_decoder_attn.py
View file @
217ee621
...
@@ -751,7 +751,7 @@ def test_encoder_only(
...
@@ -751,7 +751,7 @@ def test_encoder_only(
No KV cache is required for encoder-only attention.
No KV cache is required for encoder-only attention.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPU
s, therefore this test simply is skipped if is_hip().
hcu
s, therefore this test simply is skipped if is_hip().
This test globally forces an override of the usual backend
This test globally forces an override of the usual backend
auto-selection process, forcing the specific backend-under-test
auto-selection process, forcing the specific backend-under-test
...
@@ -860,7 +860,7 @@ def test_e2e_enc_dec_attn(
...
@@ -860,7 +860,7 @@ def test_e2e_enc_dec_attn(
to be utilized.
to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPU
s, therefore this test simply is skipped if is_hip().
hcu
s, therefore this test simply is skipped if is_hip().
Note on metadata: there is a single attention metadata structure shared by
Note on metadata: there is a single attention metadata structure shared by
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
...
...
tests/kernels/test_flash_attn.py
View file @
217ee621
...
@@ -8,8 +8,8 @@ if is_hip():
...
@@ -8,8 +8,8 @@ if is_hip():
import
flash_attn
import
flash_attn
else
:
else
:
import
vllm.attention.backends.flash_attn
# noqa: F401
import
vllm.attention.backends.flash_attn
# noqa: F401
from
tests.kernels.utils
import
opcheck
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
HEAD_SIZES
=
[
128
,
256
]
...
@@ -132,19 +132,21 @@ if not is_hip():
...
@@ -132,19 +132,21 @@ if not is_hip():
else
:
else
:
test_utils
=
[
"test_faketensor"
]
test_utils
=
[
"test_faketensor"
]
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
if
torch_version
.
startswith
(
"2.4"
):
args
=
tuple
(),
from
tests.kernels.utils
import
opcheck
kwargs
=
dict
(
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
decode_query
=
query
.
unsqueeze
(
1
),
args
=
tuple
(),
key_cache
=
key_cache
,
kwargs
=
dict
(
value_cache
=
value_cache
,
decode_query
=
query
.
unsqueeze
(
1
),
softmax_scale
=
scale
,
key_cache
=
key_cache
,
causal
=
True
,
value_cache
=
value_cache
,
block_table
=
block_tables
,
softmax_scale
=
scale
,
cache_seqlens
=
kv_lens_tensor
,
causal
=
True
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
block_table
=
block_tables
,
),
cache_seqlens
=
kv_lens_tensor
,
test_utils
=
test_utils
)
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
query
=
query
,
query
=
query
,
...
@@ -253,23 +255,25 @@ def test_varlen_with_paged_kv(
...
@@ -253,23 +255,25 @@ def test_varlen_with_paged_kv(
test_utils
=
[
"test_faketensor"
]
test_utils
=
[
"test_faketensor"
]
if
not
is_hip
():
if
not
is_hip
():
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
if
torch_version
.
startswith
(
"2.4"
):
args
=
tuple
(),
from
tests.kernels.utils
import
opcheck
kwargs
=
dict
(
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
q
=
query
,
args
=
tuple
(),
k
=
key_cache
,
kwargs
=
dict
(
v
=
value_cache
,
q
=
query
,
cu_seqlens_q
=
cu_query_lens
,
k
=
key_cache
,
cu_seqlens_k
=
cu_kv_lens
,
v
=
value_cache
,
max_seqlen_q
=
max_query_len
,
cu_seqlens_q
=
cu_query_lens
,
max_seqlen_k
=
max_kv_len
,
cu_seqlens_k
=
cu_kv_lens
,
softmax_scale
=
scale
,
max_seqlen_q
=
max_query_len
,
causal
=
True
,
max_seqlen_k
=
max_kv_len
,
window_size
=
window_size
,
softmax_scale
=
scale
,
block_table
=
block_tables
,
causal
=
True
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
),
block_table
=
block_tables
,
test_utils
=
test_utils
)
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
query
=
query
,
query
=
query
,
...
...
tests/kernels/test_int8_quant.py
View file @
217ee621
...
@@ -2,9 +2,10 @@ import pytest
...
@@ -2,9 +2,10 @@ import pytest
import
torch
import
torch
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
from
tests.kernels.quant_utils
import
ref_dynamic_per_token_quant
from
tests.kernels.utils
import
opcheck
from
vllm._custom_ops
import
scaled_int8_quant
from
vllm._custom_ops
import
scaled_int8_quant
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
vllm.utils
import
is_hip
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
HIDDEN_SIZES
=
[
16
,
67
,
768
,
2048
,
5120
,
5137
,
8192
,
...
@@ -14,30 +15,35 @@ SEEDS = [0]
...
@@ -14,30 +15,35 @@ SEEDS = [0]
SCALE
=
[
0.1
,
0.5
,
0.8
,
1.2
,
2.1
]
SCALE
=
[
0.1
,
0.5
,
0.8
,
1.2
,
2.1
]
def
opcheck_int8_quant_static
(
output
,
input
,
scale
,
azp
=
None
):
if
torch_version
.
startswith
(
"2.4"
):
if
azp
is
None
:
from
tests.kernels.utils
import
opcheck
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
,
None
))
else
:
def
opcheck_int8_quant_static
(
output
,
input
,
scale
,
azp
=
None
):
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
if
azp
is
None
:
(
output
,
input
,
scale
,
azp
))
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
(
output
,
input
,
scale
,
None
))
else
:
def
opcheck_int8_quant_dynamic
(
output
,
input
,
symmetric
=
True
):
opcheck
(
torch
.
ops
.
_C
.
static_scaled_int8_quant
,
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
(
output
,
input
,
scale
,
azp
))
device
=
input
.
device
,
dtype
=
torch
.
float32
)
if
symmetric
:
def
opcheck_int8_quant_dynamic
(
output
,
input
,
symmetric
=
True
):
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
scale
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
(
output
,
input
,
scale
,
None
))
device
=
input
.
device
,
else
:
dtype
=
torch
.
float32
)
azp
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
if
symmetric
:
device
=
input
.
device
,
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
dtype
=
torch
.
int32
)
(
output
,
input
,
scale
,
None
))
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
else
:
(
output
,
input
,
scale
,
azp
))
azp
=
torch
.
empty
((
input
.
numel
()
//
input
.
shape
[
-
1
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
int32
)
opcheck
(
torch
.
ops
.
_C
.
dynamic_scaled_int8_quant
,
(
output
,
input
,
scale
,
azp
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
@@ -54,13 +60,21 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
...
@@ -54,13 +60,21 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel
# kernel
ops_out
,
ops_scales
,
_
=
scaled_int8_quant
(
x
)
ops_out
,
ops_scales
,
_
=
scaled_int8_quant
(
x
)
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
if
torch_version
.
startswith
(
"2.3"
):
# big atol to account for rounding errors
torch
.
allclose
(
ops_scales
,
ref_scales
)
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
torch
.
allclose
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
@@ -94,13 +108,20 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
...
@@ -94,13 +108,20 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
if
(
not
torch
.
allclose
(
scales_out
,
scales
)):
if
(
not
torch
.
allclose
(
scales_out
,
scales
)):
print
(
torch
.
argmax
(
torch
.
abs
(
scales_out
-
scales
)))
print
(
torch
.
argmax
(
torch
.
abs
(
scales_out
-
scales
)))
torch
.
testing
.
assert_close
(
scales_out
,
scales
)
if
torch_version
.
startswith
(
"2.3"
):
# big atol to account for rounding errors
torch
.
allclose
(
scales_out
,
scales
)
torch
.
testing
.
assert_close
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
torch
.
allclose
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
# if AZP is off by 1, after rounding-to-even, the output may be off by 2
torch
.
allclose
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
torch
.
testing
.
assert_close
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
scales_out
,
scales
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
,
False
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
azp_out
,
azps
,
atol
=
1
,
rtol
=
0.0
)
# if AZP is off by 1, after rounding-to-even, the output may be off by 2
torch
.
testing
.
assert_close
(
ops_out
,
torch_out
,
atol
=
2
,
rtol
=
0.0
)
opcheck_int8_quant_dynamic
(
ops_out
,
x
,
False
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -122,10 +143,15 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
...
@@ -122,10 +143,15 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits
.
max
).
to
(
torch
.
int8
)
int8_traits
.
max
).
to
(
torch
.
int8
)
out2
,
_
,
_
=
scaled_int8_quant
(
x
,
scale_arg
)
out2
,
_
,
_
=
scaled_int8_quant
(
x
,
scale_arg
)
# big atol to account for rounding errors
if
torch_version
.
startswith
(
"2.3"
):
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
torch
.
allclose
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -152,11 +178,16 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
...
@@ -152,11 +178,16 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_arg
,
azp_arg
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out2
,
x
,
scale_arg
,
azp_arg
)
# big atol to account for rounding errors
if
torch_version
.
startswith
(
"2.3"
):
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
torch
.
allclose
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
elif
torch_version
.
startswith
(
"2.4"
):
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
,
azp_arg
)
# big atol to account for rounding errors
torch
.
testing
.
assert_close
(
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
opcheck_int8_quant_static
(
out2
,
x
,
scale_arg
,
azp_arg
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_max"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"is_max"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -187,4 +218,9 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
...
@@ -187,4 +218,9 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
out
=
torch
.
empty_like
(
expected
)
out
=
torch
.
empty_like
(
expected
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out
,
x
,
scale
,
azp
)
torch
.
ops
.
_C
.
static_scaled_int8_quant
(
out
,
x
,
scale
,
azp
)
torch
.
testing
.
assert_close
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
if
torch_version
.
startswith
(
"2.3"
):
torch
.
allclose
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
expected
,
out
,
atol
=
0
,
rtol
=
0
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
tests/kernels/test_layernorm.py
View file @
217ee621
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
NUM_TOKENS
=
[
7
,
83
,
4096
]
# Arbitrary values for testing
...
@@ -47,15 +47,25 @@ def test_rms_norm(
...
@@ -47,15 +47,25 @@ def test_rms_norm(
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
# Therefore, we use a larger tolerance.
if
add_residual
:
if
torch_version
.
startswith
(
"2.3"
):
torch
.
testing
.
assert_close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
if
add_residual
:
torch
.
testing
.
assert_close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
allclose
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
allclose
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
allclose
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
elif
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
if
add_residual
:
torch
.
testing
.
assert_close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
torch
.
testing
.
assert_close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
if
residual
is
not
None
:
if
residual
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
fused_add_rms_norm
,
opcheck
(
torch
.
ops
.
_C
.
fused_add_rms_norm
,
(
x
,
residual
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
(
x
,
residual
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
else
:
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
else
:
else
:
opcheck
(
torch
.
ops
.
_C
.
rms_norm
,
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
(
out
,
x
,
layer
.
weight
.
data
,
layer
.
variance_epsilon
))
tests/kernels/test_moe.py
View file @
217ee621
...
@@ -9,7 +9,6 @@ import torch
...
@@ -9,7 +9,6 @@ import torch
from
transformers
import
MixtralConfig
from
transformers
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
...
@@ -22,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
...
@@ -22,6 +21,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.model_executor.models.mixtral
import
MixtralMoE
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.utils
import
torch_version
from
vllm.utils
import
is_hip
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
def
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
):
...
@@ -76,7 +77,12 @@ def test_fused_moe(
...
@@ -76,7 +77,12 @@ def test_fused_moe(
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
if
torch_version
.
startswith
(
"2.3"
):
assert
torch
.
allclose
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
...
@@ -120,11 +126,18 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -120,11 +126,18 @@ def test_mixtral_moe(dtype: torch.dtype):
torch
.
float16
:
1e-3
,
torch
.
float16
:
1e-3
,
torch
.
bfloat16
:
1e-2
,
torch
.
bfloat16
:
1e-2
,
}
}
if
torch_version
.
startswith
(
"2.3"
):
torch
.
testing
.
assert_close
(
hf_states
.
flatten
(
0
,
1
),
assert
torch
.
allclose
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
atol
=
mixtral_moe_tol
[
dtype
])
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
def
stack_and_dev
(
tensors
:
List
[
torch
.
Tensor
]):
def
stack_and_dev
(
tensors
:
List
[
torch
.
Tensor
]):
...
@@ -137,6 +150,8 @@ def compute_max_diff(output, output_ref):
...
@@ -137,6 +150,8 @@ def compute_max_diff(output, output_ref):
torch
.
abs
(
output_ref
))
torch
.
abs
(
output_ref
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
,
512
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
,
512
])
...
@@ -256,12 +271,14 @@ def test_fused_marlin_moe(
...
@@ -256,12 +271,14 @@ def test_fused_marlin_moe(
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
a
.
device
)
device
=
a
.
device
)
opcheck
(
torch
.
ops
.
_moe_C
.
topk_softmax
,
(
if
torch_version
.
startswith
(
"2.4"
):
topk_weights
,
from
tests.kernels.utils
import
opcheck
topk_ids
,
opcheck
(
torch
.
ops
.
_moe_C
.
topk_softmax
,
(
token_expert_indicies
,
topk_weights
,
score
.
float
(),
topk_ids
,
))
token_expert_indicies
,
score
.
float
(),
))
block_size_m
=
4
block_size_m
=
4
...
@@ -274,12 +291,16 @@ def test_fused_marlin_moe(
...
@@ -274,12 +291,16 @@ def test_fused_marlin_moe(
device
=
"cuda"
,
device
=
"cuda"
,
requires_grad
=
False
)
requires_grad
=
False
)
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
if
torch_version
.
startswith
(
"2.4"
):
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
from
tests.kernels.utils
import
opcheck
scales1
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
opcheck
(
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
(
a
,
qweight1
,
sorted_token_ids
,
topk_weights
,
topk_ids
,
scales1
,
g_idx1
,
sort_indices1
,
workspace
,
quant_type
,
m
,
2
*
n
,
k
,
True
,
e
,
topk
,
block_size_m
,
True
,
False
))
@
pytest
.
mark
.
skipif
(
is_hip
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
"don't run it in automated tests."
)
"don't run it in automated tests."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
...
@@ -373,7 +394,8 @@ def test_moe_align_block_size_opcheck():
...
@@ -373,7 +394,8 @@ def test_moe_align_block_size_opcheck():
num_tokens_post_pad
=
torch
.
empty
((
1
),
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
device
=
topk_ids
.
device
)
if
torch_version
.
startswith
(
"2.4"
):
opcheck
(
torch
.
ops
.
_C
.
moe_align_block_size
,
from
tests.kernels.utils
import
opcheck
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
opcheck
(
torch
.
ops
.
_C
.
moe_align_block_size
,
num_tokens_post_pad
))
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
))
tests/kernels/test_pos_encoding.py
View file @
217ee621
...
@@ -8,6 +8,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -8,6 +8,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.utils
import
seed_everything
from
vllm.utils
import
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.allclose_default
import
get_default_atol
,
get_default_rtol
from
.utils
import
torch_version
IS_NEOX_STYLE
=
[
True
,
False
]
IS_NEOX_STYLE
=
[
True
,
False
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -18,7 +19,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing
...
@@ -18,7 +19,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS
=
[
11
,
8192
]
# Arbitrary values for testing
SEQ_LENS
=
[
11
,
8192
]
# Arbitrary values for testing
SEEDS
=
[
0
]
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
1
)
]
]
...
@@ -67,14 +68,26 @@ def test_rotary_embedding(
...
@@ -67,14 +68,26 @@ def test_rotary_embedding(
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
# Compare the results.
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
if
torch_version
.
startswith
(
"2.3"
):
ref_query
,
torch
.
allclose
(
out_query
,
atol
=
get_default_atol
(
out_query
),
ref_query
,
rtol
=
get_default_rtol
(
out_query
))
atol
=
get_default_atol
(
out_query
),
torch
.
testing
.
assert_close
(
out_key
,
rtol
=
get_default_rtol
(
out_query
))
ref_key
,
torch
.
allclose
(
out_key
,
atol
=
get_default_atol
(
out_key
),
ref_key
,
rtol
=
get_default_rtol
(
out_key
))
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -126,15 +139,27 @@ def test_batched_rotary_embedding(
...
@@ -126,15 +139,27 @@ def test_batched_rotary_embedding(
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
))
device
=
device
))
# Compare the results.
if
torch_version
.
startswith
(
"2.3"
):
torch
.
testing
.
assert_close
(
out_query
,
torch
.
allclose
(
out_query
,
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
torch
.
allclose
(
out_key
,
ref_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -195,16 +220,27 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -195,16 +220,27 @@ def test_batched_rotary_embedding_multi_lora(
query_offsets
)
query_offsets
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
query_offsets
.
flatten
())
query_offsets
.
flatten
())
# Compare the results.
if
torch_version
.
startswith
(
"2.3"
):
torch
.
testing
.
assert_close
(
out_query
,
torch
.
allclose
(
out_query
,
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
torch
.
allclose
(
out_key
,
ref_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
rtol
=
get_default_rtol
(
out_key
))
elif
torch_version
.
startswith
(
"2.4"
):
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
print
(
f
"PyTorch version
{
torch_version
}
is not specifically handled."
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rope_module_cache
():
def
test_rope_module_cache
():
...
...
tests/kernels/test_rotary_embedding.py
View file @
217ee621
...
@@ -7,8 +7,11 @@ from typing import Optional
...
@@ -7,8 +7,11 @@ from typing import Optional
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
.utils
import
torch_version
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
rotary_embedding_opcheck
(
rot
,
def
rotary_embedding_opcheck
(
rot
,
...
@@ -30,6 +33,8 @@ def rotary_embedding_opcheck(rot,
...
@@ -30,6 +33,8 @@ def rotary_embedding_opcheck(rot,
rot
.
is_neox_style
))
rot
.
is_neox_style
))
@
pytest
.
mark
.
skipif
(
torch_version
.
startswith
(
"2.3"
),
reason
=
"Need torch2.4."
)
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"max_position"
,
[
11
,
4096
,
32768
])
@
pytest
.
mark
.
parametrize
(
"max_position"
,
[
11
,
4096
,
32768
])
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
[
True
,
False
])
...
...
tests/kernels/test_utils.py
View file @
217ee621
...
@@ -5,14 +5,17 @@ Tests for miscellaneous utilities
...
@@ -5,14 +5,17 @@ Tests for miscellaneous utilities
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.utils
import
torch_version
if
torch_version
.
startswith
(
"2.4"
):
from
tests.kernels.utils
import
opcheck
def
test_convert_fp8_opcheck
():
data
=
torch
.
randn
((
256
,
256
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
# def test_convert_fp8_opcheck():
result
=
torch
.
empty_like
(
data
,
dtype
=
torch
.
float8_e4m3fn
)
# data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
opcheck
(
torch
.
ops
.
_C_cache_ops
.
convert_fp8
,
(
result
,
data
,
1.0
,
"fp8"
))
# result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
# opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
...
...
tests/kernels/test_aqlm.py
→
tests/kernels/
un
test_aqlm.py
View file @
217ee621
File moved
tests/kernels/test_awq.py
→
tests/kernels/
un
test_awq.py
View file @
217ee621
File moved
tests/kernels/test_causal_conv1d.py
→
tests/kernels/
un
test_causal_conv1d.py
View file @
217ee621
File moved
tests/kernels/test_flashinfer.py
→
tests/kernels/
un
test_flashinfer.py
View file @
217ee621
File moved
tests/kernels/test_fp8_quant.py
→
tests/kernels/
un
test_fp8_quant.py
View file @
217ee621
File moved
tests/kernels/test_ggml.py
→
tests/kernels/
un
test_ggml.py
View file @
217ee621
File moved
tests/kernels/test_gguf.py
→
tests/kernels/
un
test_gguf.py
View file @
217ee621
File moved
tests/kernels/test_gptq.py
→
tests/kernels/
un
test_gptq.py
View file @
217ee621
File moved
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment