Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
679
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1062 additions
and
252 deletions
+1062
-252
tests/entrypoints/test_responses_utils.py
tests/entrypoints/test_responses_utils.py
+45
-7
tests/kernels/attention/test_mha_attn.py
tests/kernels/attention/test_mha_attn.py
+9
-2
tests/kernels/core/test_fused_quant_layernorm.py
tests/kernels/core/test_fused_quant_layernorm.py
+63
-19
tests/kernels/core/test_mrope.py
tests/kernels/core/test_mrope.py
+2
-6
tests/kernels/mamba/test_mamba_ssm.py
tests/kernels/mamba/test_mamba_ssm.py
+325
-0
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+0
-17
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+15
-1
tests/kernels/moe/test_moe_align_block_size.py
tests/kernels/moe/test_moe_align_block_size.py
+17
-5
tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py
...s/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py
+86
-0
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+1
-1
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+14
-3
tests/kernels/quantization/test_cutlass_scaled_mm.py
tests/kernels/quantization/test_cutlass_scaled_mm.py
+3
-0
tests/kernels/quantization/test_cutlass_w4a8.py
tests/kernels/quantization/test_cutlass_w4a8.py
+44
-5
tests/kernels/quantization/test_cutlass_w4a8_moe.py
tests/kernels/quantization/test_cutlass_w4a8_moe.py
+340
-0
tests/kernels/quantization/test_hadacore.py
tests/kernels/quantization/test_hadacore.py
+7
-0
tests/kernels/quantization/test_machete_mm.py
tests/kernels/quantization/test_machete_mm.py
+6
-0
tests/kernels/quantization/test_marlin_gemm.py
tests/kernels/quantization/test_marlin_gemm.py
+8
-0
tests/kernels/test_top_k_per_row.py
tests/kernels/test_top_k_per_row.py
+77
-18
tests/kv_transfer/test_lookup_buffer.py
tests/kv_transfer/test_lookup_buffer.py
+0
-160
tests/kv_transfer/test_lookup_buffer.sh
tests/kv_transfer/test_lookup_buffer.sh
+0
-8
No files found.
Too many changes to show.
To preserve performance only
679 of 679+
files are displayed.
Plain diff
Email patch
tests/entrypoints/test_responses_utils.py
View file @
8d75f22e
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
openai.types.responses.response_function_tool_call
import
ResponseFunctionToolCall
from
openai.types.responses.response_function_tool_call_output_item
import
(
ResponseFunctionToolCallOutputItem
,
)
...
...
@@ -14,7 +15,8 @@ from openai.types.responses.response_reasoning_item import (
)
from
vllm.entrypoints.responses_utils
import
(
construct_chat_message_with_tool_call
,
_construct_single_message_from_response_item
,
construct_chat_messages_with_tool_call
,
convert_tool_responses_to_completions_format
,
)
...
...
@@ -42,7 +44,43 @@ class TestResponsesUtils:
assert
result
==
{
"type"
:
"function"
,
"function"
:
input_tool
}
def
test_construct_chat_message_with_tool_call
(
self
):
def test_construct_chat_messages_with_tool_call(self):
    """A reasoning item followed by a function call yields one assistant message.

    The reasoning text and the tool call are expected to be merged into a
    single chat message that carries both the ``reasoning`` field and a
    one-element ``tool_calls`` list.
    """
    expected_arguments = '{"code": "123+456"}'

    reasoning = ResponseReasoningItem(
        id="lol",
        summary=[],
        type="reasoning",
        content=[Content(text="Leroy Jenkins", type="reasoning_text")],
        encrypted_content=None,
        status=None,
    )
    tool_call_item = ResponseFunctionToolCall(
        id="mcp_123",
        call_id="call_123",
        type="function_call",
        status="completed",
        name="python",
        arguments=expected_arguments,
    )

    chat_messages = construct_chat_messages_with_tool_call(
        [reasoning, tool_call_item]
    )

    # Both input items must collapse into exactly one message.
    assert len(chat_messages) == 1
    merged = chat_messages[0]
    assert merged["role"] == "assistant"
    assert merged["reasoning"] == "Leroy Jenkins"

    # The tool call keeps the call_id (not the item id) and its payload.
    assert merged["tool_calls"][0]["id"] == "call_123"
    assert merged["tool_calls"][0]["function"]["name"] == "python"
    assert merged["tool_calls"][0]["function"]["arguments"] == expected_arguments
def
test_construct_single_message_from_response_item
(
self
):
item
=
ResponseReasoningItem
(
id
=
"lol"
,
summary
=
[],
...
...
@@ -56,7 +94,7 @@ class TestResponsesUtils:
encrypted_content
=
None
,
status
=
None
,
)
formatted_item
=
construct_
chat
_message_
with_tool_call
(
item
)
formatted_item
=
_
construct_
single
_message_
from_response_item
(
item
)
assert
formatted_item
[
"role"
]
==
"assistant"
assert
formatted_item
[
"reasoning"
]
==
"Leroy Jenkins"
...
...
@@ -74,7 +112,7 @@ class TestResponsesUtils:
status
=
None
,
)
formatted_item
=
construct_
chat
_message_
with_tool_call
(
item
)
formatted_item
=
_
construct_
single
_message_
from_response_item
(
item
)
assert
formatted_item
[
"role"
]
==
"assistant"
assert
(
formatted_item
[
"reasoning"
]
...
...
@@ -88,7 +126,7 @@ class TestResponsesUtils:
output
=
"1234"
,
status
=
"completed"
,
)
formatted_item
=
construct_
chat
_message_
with_tool_call
(
tool_call_output
)
formatted_item
=
_
construct_
single
_message_
from_response_item
(
tool_call_output
)
assert
formatted_item
[
"role"
]
==
"tool"
assert
formatted_item
[
"content"
]
==
"1234"
assert
formatted_item
[
"tool_call_id"
]
==
"temp"
...
...
@@ -102,7 +140,7 @@ class TestResponsesUtils:
status
=
None
,
)
with
pytest
.
raises
(
ValueError
):
construct_
chat
_message_
with_tool_call
(
item
)
_
construct_
single
_message_
from_response_item
(
item
)
output_item
=
ResponseOutputMessage
(
id
=
"msg_bf585bbbe3d500e0"
,
...
...
@@ -119,6 +157,6 @@ class TestResponsesUtils:
type
=
"message"
,
)
formatted_item
=
construct_
chat
_message_
with_tool_call
(
output_item
)
formatted_item
=
_
construct_
single
_message_
from_response_item
(
output_item
)
assert
formatted_item
[
"role"
]
==
"assistant"
assert
formatted_item
[
"content"
]
==
"dongyi"
tests/kernels/attention/test_mha_attn.py
View file @
8d75f22e
...
...
@@ -26,7 +26,14 @@ def clear_cache():
_cached_get_attn_backend
.
cache_clear
()
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
])
# Devices to parametrize over: CPU always, plus whichever accelerator
# backend(s) the current platform reports.
devices = (
    ["cpu"]
    + (["cuda"] if current_platform.is_cuda() else [])
    + (["hip"] if current_platform.is_rocm() else [])
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_mha_attn_platform
(
device
:
str
):
"""
Test the attention selector between different platform and device.
...
...
@@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
patch
(
"vllm.model_executor.models.vision.current_platform"
,
RocmPlatform
()),
):
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
TORCH_SDPA
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
else
:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
...
...
tests/kernels/core/test_fused_quant_layernorm.py
View file @
8d75f22e
...
...
@@ -8,6 +8,12 @@ import torch
import
vllm._custom_ops
as
ops
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
)
from
vllm.model_executor.layers.quantization.utils.int8_utils
import
(
per_token_group_quant_int8
,
)
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float
]
QUANT_DTYPES
=
[
torch
.
int8
,
torch
.
float8_e4m3fn
]
...
...
@@ -21,6 +27,7 @@ NUM_TOKENS_HIDDEN_SIZES = [
ADD_RESIDUAL
=
[
False
,
True
]
SCALE_UBS
=
[
True
,
False
]
GROUP_SIZES
=
[
None
,
[
1
,
64
],
[
1
,
128
]]
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)]
...
...
@@ -45,12 +52,13 @@ def ref_rms_norm(
return
out
,
residual
def
ref_dynamic_per_token_quant
(
def
ref_dynamic_per_token_
or_block_
quant
(
rms_norm_layer
:
RMSNorm
,
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
residual
:
torch
.
Tensor
|
None
,
scale_ub
:
torch
.
Tensor
|
None
,
group_size
:
list
[
int
]
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
if
scale_ub
is
not
None
:
assert
quant_dtype
==
torch
.
float8_e4m3fn
...
...
@@ -59,6 +67,17 @@ def ref_dynamic_per_token_quant(
torch_out
,
residual
=
ref_rms_norm
(
rms_norm_layer
,
x
,
residual
)
# Quant
if
group_size
is
not
None
:
if
quant_dtype
==
torch
.
float8_e4m3fn
:
torch_out
,
scales
=
per_token_group_quant_fp8
(
torch_out
,
group_size
=
group_size
[
1
],
use_ue8m0
=
False
)
else
:
assert
quant_dtype
==
torch
.
int8
torch_out
,
scales
=
per_token_group_quant_int8
(
torch_out
,
group_size
=
group_size
[
1
]
)
else
:
if
quant_dtype
==
torch
.
float8_e4m3fn
:
torch_out
,
scales
=
ops
.
scaled_fp8_quant
(
torch_out
,
scale_ub
=
scale_ub
,
use_per_token_if_dynamic
=
True
...
...
@@ -76,21 +95,29 @@ def ref_impl(
quant_dtype
:
torch
.
dtype
,
residual
:
torch
.
Tensor
|
None
,
scale_ub
:
torch
.
Tensor
|
None
,
group_size
:
list
[
int
]
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
return
ref_dynamic_per_token_quant
(
rms_norm_layer
,
x
,
quant_dtype
,
residual
,
scale_ub
return
ref_dynamic_per_token_
or_block_
quant
(
rms_norm_layer
,
x
,
quant_dtype
,
residual
,
scale_ub
,
group_size
)
def
ops_dynamic_per_token_quant
(
def
ops_dynamic_per_token_
or_block_
quant
(
weight
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
residual
:
torch
.
Tensor
|
None
,
scale_ub
:
torch
.
Tensor
|
None
,
group_size
:
list
[
int
]
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
if
residual
is
not
None
:
residual
=
residual
.
clone
()
if
group_size
is
not
None
:
out
,
scales
=
ops
.
rms_norm_per_block_quant
(
x
,
weight
,
EPS
,
quant_dtype
,
group_size
,
scale_ub
,
residual
,
True
)
scales
=
scales
.
contiguous
()
else
:
out
,
scales
=
ops
.
rms_norm_dynamic_per_token_quant
(
x
,
weight
,
EPS
,
quant_dtype
,
scale_ub
,
residual
)
...
...
@@ -103,8 +130,11 @@ def ops_impl(
quant_dtype
:
torch
.
dtype
,
residual
:
torch
.
Tensor
|
None
,
scale_ub
:
torch
.
Tensor
|
None
,
group_size
:
list
[
int
]
|
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
return
ops_dynamic_per_token_quant
(
weight
,
x
,
quant_dtype
,
residual
,
scale_ub
)
return
ops_dynamic_per_token_or_block_quant
(
weight
,
x
,
quant_dtype
,
residual
,
scale_ub
,
group_size
)
@
pytest
.
mark
.
parametrize
(
"num_tokens, hidden_size"
,
NUM_TOKENS_HIDDEN_SIZES
)
...
...
@@ -112,6 +142,7 @@ def ops_impl(
@
pytest
.
mark
.
parametrize
(
"has_scale_ub"
,
SCALE_UBS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"quant_dtype"
,
QUANT_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
...
...
@@ -122,6 +153,7 @@ def test_rms_norm(
has_scale_ub
:
bool
,
dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
group_size
:
list
[
int
]
|
None
,
seed
:
int
,
device
:
str
,
)
->
None
:
...
...
@@ -130,6 +162,14 @@ def test_rms_norm(
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
if
group_size
is
not
None
and
hidden_size
%
group_size
[
1
]
!=
0
:
# skip
return
if
group_size
is
not
None
and
has_scale_ub
:
# blockwise baseline doesn't support scale_ub
return
if
has_scale_ub
and
quant_dtype
!=
torch
.
float8_e4m3fn
:
# skip
return
...
...
@@ -150,10 +190,10 @@ def test_rms_norm(
scale_ub
=
None
ref_out
,
ref_scales
,
ref_residual
=
ref_impl
(
layer
,
x
,
quant_dtype
,
residual
,
scale_ub
layer
,
x
,
quant_dtype
,
residual
,
scale_ub
,
group_size
)
ops_out
,
ops_scales
,
ops_residual
=
ops_impl
(
layer
.
weight
,
x
,
quant_dtype
,
residual
,
scale_ub
layer
.
weight
,
x
,
quant_dtype
,
residual
,
scale_ub
,
group_size
)
assert
ref_out
.
dtype
==
quant_dtype
...
...
@@ -166,11 +206,15 @@ def test_rms_norm(
assert
torch
.
allclose
(
ref_scales
,
ops_scales
)
a
=
ref_out
.
to
(
dtype
=
torch
.
float32
)
b
=
ops_out
.
to
(
dtype
=
torch
.
float32
)
ok
=
torch
.
allclose
(
a
,
b
)
ok
=
torch
.
allclose
(
a
,
b
,
atol
=
1e-6
)
if
not
ok
:
# fallback: compare dequantized values with relaxed tolerance
if
group_size
is
None
:
a_deq
=
a
*
ref_scales
.
view
(
-
1
,
1
)
b_deq
=
b
*
ops_scales
.
view
(
-
1
,
1
)
else
:
a_deq
=
a
*
ref_scales
.
repeat_interleave
(
group_size
[
1
],
dim
=
1
)
b_deq
=
b
*
ops_scales
.
repeat_interleave
(
group_size
[
1
],
dim
=
1
)
# NOTE: It is possible that some future test cases trigger this
# max diff due to precision issues. If such an error is
# encountered, it's recommended to inspect the differences between
...
...
tests/kernels/core/test_mrope.py
View file @
8d75f22e
...
...
@@ -113,12 +113,10 @@ def test_mrope(
is_neox_style
=
True
max_position
=
config
.
max_position_embeddings
partial_rotary_factor
=
getattr
(
config
,
"partial_rotary_factor"
,
1.0
)
rotary_dim
=
int
(
head_dim
*
partial_rotary_factor
)
mrope_helper_class
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
rotary
_dim
,
rotary_dim
=
head
_dim
,
max_position
=
max_position
,
is_neox_style
=
is_neox_style
,
rope_parameters
=
config
.
rope_parameters
,
...
...
@@ -184,12 +182,10 @@ def test_mrope_torch_compile_tracing(
)
is_neox_style
=
True
max_position
=
config
.
max_position_embeddings
partial_rotary_factor
=
getattr
(
config
,
"partial_rotary_factor"
,
1.0
)
rotary_dim
=
int
(
head_dim
*
partial_rotary_factor
)
mrope_helper_class
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
rotary
_dim
,
rotary_dim
=
head
_dim
,
max_position
=
max_position
,
is_neox_style
=
is_neox_style
,
rope_parameters
=
config
.
rope_parameters
,
...
...
tests/kernels/mamba/test_mamba_ssm.py
View file @
8d75f22e
...
...
@@ -425,6 +425,80 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
    """Compare selective_state_update with cu_seqlens against a per-token reference.

    A batch of 4 sequences with random lengths in [1, max_seq_len] is packed
    into flat (total_tokens, ...) tensors. The kernel is called once with
    ``cu_seqlens``; the reference calls ``selective_state_update_ref`` one
    token at a time on the matching state slice. Final states and outputs
    must agree within the dtype-dependent tolerances.
    """
    device = "cuda"
    # Tolerances: tight for float32, looser otherwise; bfloat16 overrides both
    # and the absolute tolerance is doubled when torch reports a HIP build.
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        if torch.version.hip:
            atol *= 2
    # set seed
    current_platform.seed_everything(0)
    batch_size = 4
    # Per-sequence token counts, each in [1, max_seq_len].
    token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    total_tokens = int(token_counts.sum().item())
    # Cumulative token offsets with a leading 0: sequence i owns the flat
    # token range [cu_seqlens[i], cu_seqlens[i + 1]).
    cu_seqlens = torch.tensor(
        [0] + torch.cumsum(token_counts, dim=0).tolist(),
        dtype=torch.int32,
        device=device,
    )
    # One state entry per sequence; token-level inputs are packed along dim 0.
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    # Independent copy of the initial state for the reference path.
    state_ref = state.detach().clone()
    # Kernel under test: writes into `out`.
    # NOTE(review): presumably also updates `state` in place, since `state`
    # is compared against `state_ref` below -- confirm against the kernel.
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
    )

    # Reference: walk every token of every sequence in packed order, feeding
    # single-token slices and the sequence's own state slice.
    out_ref_list = []
    for seq_idx in range(batch_size):
        start_idx = cu_seqlens[seq_idx].item()
        end_idx = cu_seqlens[seq_idx + 1].item()
        num_tokens = end_idx - start_idx
        for token_idx in range(num_tokens):
            idx = start_idx + token_idx
            out_ref_list.append(
                selective_state_update_ref(
                    # Slice (not index) so the batch dimension is kept.
                    state_ref[seq_idx : seq_idx + 1],
                    x[idx : idx + 1],
                    dt[idx : idx + 1],
                    A,
                    B[idx : idx + 1],
                    C[idx : idx + 1],
                    D=D,
                    z=z[idx : idx + 1] if has_z else None,
                    dt_bias=dt_bias,
                    dt_softplus=True,
                )
            )
    # Per-token reference outputs are concatenated back into packed layout.
    out_ref = torch.cat(out_ref_list, dim=0)
    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@
pytest
.
mark
.
parametrize
(
"wtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
256
,
1024
,
4096
])
...
...
@@ -766,3 +840,254 @@ def test_selective_state_update_with_heads_with_batch_indices(
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
assert
torch
.
allclose
(
state
[
state_indices
,
:],
state_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_with_num_accepted_tokens(
    dim, dstate, has_z, itype, max_seq_len
):
    """Check outputs and per-token destination states with num_accepted_tokens.

    Each sequence reads its initial state from the slot stored in
    ``state_batch_indices`` at column ``max(num_accepted_tokens - 1, 0)``.
    Every token gets its own destination slot via ``dst_state_batch_indices``.
    The reference runs ``selective_state_update_ref`` token by token and
    snapshots the state after each token; the kernel's packed output and the
    contents of every destination slot must match.
    """
    device = "cuda"
    # Tolerances: tight for float32, looser otherwise; bfloat16 overrides both
    # and the absolute tolerance is doubled when torch reports a HIP build.
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        if torch.version.hip:
            atol *= 2

    current_platform.seed_everything(0)
    batch_size = 4

    # Per-sequence token counts, each in [1, max_seq_len].
    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    total_tokens = int(tokens_per_seq.sum().item())

    # Random values in [0, max_seq_len), then two forced edge cases.
    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted

    # Cumulative token offsets with a leading 0: sequence i owns the flat
    # token range [cu_seqlens[i], cu_seqlens[i + 1]).
    cu_seqlens = torch.tensor(
        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
        dtype=torch.int32,
        device=device,
    )

    # State pool: source slots are drawn from [0, 15), destination slots are
    # handed out from 15 upward, so the two ranges never overlap.
    total_state_slots = 50
    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)

    # Source-slot table, padded everywhere except one column per sequence.
    state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )
    initial_state_slots = torch.randint(
        0, 15, (batch_size,), device=device, dtype=torch.int32
    )
    for seq_idx in range(batch_size):
        # Column holding the initial slot; clamped so 0 accepted maps to 0.
        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]

    # Destination-slot table: one unique slot per (sequence, token) pair.
    dst_state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )
    slot_offset = 15
    dst_slots_map = {}
    for seq_idx in range(batch_size):
        for token_idx in range(tokens_per_seq[seq_idx].item()):
            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
            dst_slots_map[(seq_idx, token_idx)] = slot_offset
            slot_offset += 1

    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None

    # Reference pass runs BEFORE the kernel so it reads the untouched pool.
    state_ref_intermediate = {}
    out_ref_list = []

    for seq_idx in range(batch_size):
        seq_start = cu_seqlens[seq_idx].item()
        seq_end = cu_seqlens[seq_idx + 1].item()
        num_tokens = seq_end - seq_start

        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
        initial_slot = state_batch_indices[seq_idx, token_pos].item()
        # Private copy of the sequence's initial state, batch dim kept.
        state_seq = state[initial_slot : initial_slot + 1].clone()

        for token_idx in range(num_tokens):
            global_idx = seq_start + token_idx

            out_token = selective_state_update_ref(
                state_seq,
                x[global_idx : global_idx + 1],
                dt[global_idx : global_idx + 1],
                A,
                B[global_idx : global_idx + 1],
                C[global_idx : global_idx + 1],
                D=D,
                z=z[global_idx : global_idx + 1] if has_z else None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            out_ref_list.append(out_token)
            # Snapshot the state after this token.
            # NOTE(review): relies on the reference updating `state_seq`
            # in place -- confirm against selective_state_update_ref.
            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()

    out_ref = torch.cat(out_ref_list, dim=0)

    # Kernel under test, using the slot tables built above.
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        num_accepted_tokens=num_accepted_tokens,
        pad_slot_id=PAD_SLOT_ID,
    )

    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

    # Every destination slot must hold the state snapshot for its token.
    for seq_idx in range(batch_size):
        num_tokens = tokens_per_seq[seq_idx].item()
        for token_idx in range(num_tokens):
            dst_slot = dst_slots_map[(seq_idx, token_idx)]
            # Drop the leading batch dim of size 1 kept by the slice above.
            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_varlen_with_num_accepted(
    dim, dstate, has_z, itype, max_seq_len
):
    """Check per-token destination states for varlen input with num_accepted_tokens.

    The setup mirrors the slot-table construction used elsewhere in this
    file: each sequence reads its initial state from the slot stored at
    column ``max(num_accepted_tokens - 1, 0)`` of ``state_batch_indices`` and
    each token writes to a unique slot in ``dst_state_batch_indices``.

    NOTE(review): only the destination states are asserted here; the kernel's
    ``out`` tensor is produced but never compared against a reference.
    """
    device = "cuda"
    # Tolerances: tight for float32, looser otherwise; bfloat16 overrides both
    # and the absolute tolerance is doubled when torch reports a HIP build.
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        if torch.version.hip:
            atol *= 2

    current_platform.seed_everything(0)
    batch_size = 4

    # Per-sequence token counts, each in [1, max_seq_len].
    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    total_tokens = int(tokens_per_seq.sum().item())

    # Random values in [0, max_seq_len), then two forced edge cases.
    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted

    # Cumulative token offsets with a leading 0: sequence i owns the flat
    # token range [cu_seqlens[i], cu_seqlens[i + 1]).
    cu_seqlens = torch.tensor(
        [0] + torch.cumsum(tokens_per_seq, dim=0).tolist(),
        dtype=torch.int32,
        device=device,
    )

    # State pool: source slots are drawn from [0, 15), destination slots are
    # handed out from 15 upward, so the two ranges never overlap.
    total_state_slots = 50
    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)

    # Source-slot table, padded everywhere except one column per sequence.
    state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )

    initial_state_slots = torch.randint(
        0, 15, (batch_size,), device=device, dtype=torch.int32
    )
    for seq_idx in range(batch_size):
        # Column holding the initial slot; clamped so 0 accepted maps to 0.
        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
        state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx]

    # Destination-slot table: one unique slot per (sequence, token) pair.
    dst_state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )

    slot_offset = 15
    dst_slots_map = {}
    for seq_idx in range(batch_size):
        for token_idx in range(tokens_per_seq[seq_idx].item()):
            dst_state_batch_indices[seq_idx, token_idx] = slot_offset
            dst_slots_map[(seq_idx, token_idx)] = slot_offset
            slot_offset += 1

    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None

    # Reference pass runs BEFORE the kernel so it reads the untouched pool.
    state_ref_intermediate = {}

    for seq_idx in range(batch_size):
        seq_start = cu_seqlens[seq_idx].item()
        seq_end = cu_seqlens[seq_idx + 1].item()
        num_tokens = seq_end - seq_start

        token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0)
        initial_slot = state_batch_indices[seq_idx, token_pos].item()
        # Private copy of the sequence's initial state, batch dim kept.
        state_seq = state[initial_slot : initial_slot + 1].clone()

        for token_idx in range(num_tokens):
            global_idx = seq_start + token_idx

            # Return value (the reference output) is intentionally unused.
            selective_state_update_ref(
                state_seq,
                x[global_idx : global_idx + 1],
                dt[global_idx : global_idx + 1],
                A,
                B[global_idx : global_idx + 1],
                C[global_idx : global_idx + 1],
                D=D,
                z=z[global_idx : global_idx + 1] if has_z else None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )

            # Snapshot the state after this token.
            # NOTE(review): relies on the reference updating `state_seq`
            # in place -- confirm against selective_state_update_ref.
            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()

    # Kernel under test, using the slot tables built above.
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        num_accepted_tokens=num_accepted_tokens,
        pad_slot_id=PAD_SLOT_ID,
    )

    # Every destination slot must hold the state snapshot for its token.
    for seq_idx in range(batch_size):
        num_tokens = tokens_per_seq[seq_idx].item()

        for token_idx in range(num_tokens):
            dst_slot = dst_slots_map[(seq_idx, token_idx)]
            # Drop the leading batch dim of size 1 kept by the slice above.
            state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)

            assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol)
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
8d75f22e
...
...
@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
,
)
from
vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe
import
(
BatchedTritonOrDeepGemmExperts
,
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
...
...
@@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant
=
False
,
needs_deep_gemm
=
True
,
)
register_experts
(
BatchedTritonOrDeepGemmExperts
,
batched_format
,
common_float_and_int_types
,
blocked_quantization_support
=
True
,
supports_chunking
=
False
,
supports_expert_map
=
False
,
needs_matching_quant
=
True
,
needs_deep_gemm
=
True
,
)
register_experts
(
TritonOrDeepGemmExperts
,
standard_format
,
...
...
@@ -457,10 +444,6 @@ def make_fused_experts(
kwargs
=
batch_kwargs
|
quant_kwargs
print
(
f
"Making BatchedTritonExperts
{
kwargs
}
..."
)
experts
=
BatchedTritonExperts
(
**
kwargs
)
elif
fused_experts_type
==
BatchedTritonOrDeepGemmExperts
:
kwargs
=
batch_kwargs
|
quant_kwargs
|
deepgemm_kwargs
print
(
f
"Making BatchedTritonOrDeepGemmExperts
{
kwargs
}
..."
)
experts
=
BatchedTritonOrDeepGemmExperts
(
**
kwargs
)
elif
fused_experts_type
==
DeepGemmExperts
:
print
(
f
"Making DeepGemmExperts
{
quant_config
}
..."
)
experts
=
DeepGemmExperts
(
quant_config
)
...
...
tests/kernels/moe/test_moe.py
View file @
8d75f22e
...
...
@@ -955,9 +955,22 @@ def test_fused_marlin_moe_with_bias(m):
torch
.
testing
.
assert_close
(
marlin_output
,
torch_output
,
atol
=
5e-2
,
rtol
=
0
)
def
test_moe_align_block_size_opcheck
():
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
[
1
,
2
])
def
test_moe_align_block_size_opcheck
(
ep_size
):
num_experts
=
4
block_size
=
4
expert_map
=
None
if
ep_size
!=
1
:
local_num_experts
=
num_experts
//
ep_size
expert_ids
=
torch
.
randint
(
0
,
num_experts
,
(
local_num_experts
,),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
expert_map
=
torch
.
full
((
num_experts
,),
-
1
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
expert_map
[
expert_ids
]
=
torch
.
arange
(
local_num_experts
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
topk_ids
=
torch
.
randint
(
0
,
num_experts
,
(
3
,
4
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
...
...
@@ -980,6 +993,7 @@ def test_moe_align_block_size_opcheck():
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
expert_map
,
),
)
...
...
tests/kernels/moe/test_moe_align_block_size.py
View file @
8d75f22e
...
...
@@ -106,6 +106,8 @@ def torch_moe_align_block_size(
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
if
pad_sorted_ids
:
max_num_tokens_padded
=
round_up
(
max_num_tokens_padded
,
block_size
)
if
topk_ids
.
numel
()
<
num_experts
:
max_num_tokens_padded
=
topk_ids
.
numel
()
*
block_size
flattened_token_indices
=
torch
.
arange
(
topk_ids
.
numel
(),
device
=
topk_ids
.
device
,
dtype
=
torch
.
int32
...
...
@@ -126,6 +128,8 @@ def torch_moe_align_block_size(
)
for
expert_id
in
range
(
num_experts
):
original_count
=
expert_token_counts
[
expert_id
]
if
expert_map
is
not
None
and
expert_map
[
expert_id
]
==
-
1
:
continue
if
original_count
>
0
:
expert_padded_counts
[
expert_id
]
=
(
(
original_count
+
block_size
-
1
)
//
block_size
...
...
@@ -143,6 +147,9 @@ def torch_moe_align_block_size(
current_pos
=
0
current_block
=
0
for
expert_id
in
range
(
num_experts
):
if
expert_map
is
not
None
and
expert_map
[
expert_id
]
==
-
1
:
continue
expert_mask
=
sorted_expert_ids
==
expert_id
expert_tokens
=
sorted_token_indices
[
expert_mask
]
num_expert_tokens
=
expert_tokens
.
shape
[
0
]
...
...
@@ -153,7 +160,13 @@ def torch_moe_align_block_size(
)
expert_blocks_needed
=
expert_padded_counts
[
expert_id
]
//
block_size
expert_ids
[
current_block
:
current_block
+
expert_blocks_needed
]
=
expert_id
expert_id_new
=
expert_id
if
expert_map
is
not
None
:
expert_id_new
=
expert_map
[
expert_id
]
expert_ids
[
current_block
:
current_block
+
expert_blocks_needed
]
=
(
expert_id_new
)
current_pos
+=
expert_padded_counts
[
expert_id
]
current_block
+=
expert_blocks_needed
...
...
@@ -163,8 +176,6 @@ def torch_moe_align_block_size(
[
total_padded_tokens
],
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
if
expert_map
is
not
None
:
expert_ids
=
expert_map
[
expert_ids
]
return
sorted_token_ids
,
expert_ids
,
num_tokens_post_pad
...
...
@@ -229,9 +240,9 @@ def test_moe_align_block_size(
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
16
,
32
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
16
,
32
,
2048
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
2
,
4
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
8
,
64
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
])
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Skip for rocm"
)
def
test_moe_align_block_size_with_expert_map
(
...
...
@@ -253,6 +264,7 @@ def test_moe_align_block_size_with_expert_map(
block_size
=
block_size
,
num_experts
=
num_experts
,
expert_map
=
expert_map
,
ignore_invalid_experts
=
True
,
)
golden_sorted_ids
,
golden_expert_ids
,
golden_num_tokens
=
(
torch_moe_align_block_size
(
...
...
tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
_per_token_group_quant_fp8_colmajor
,
silu_mul_per_token_group_quant_fp8_colmajor
,
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
# dtype of the quantized output tensors produced and compared in this file.
FLOAT8_DTYPE = torch.float8_e4m3fn
# Number of elements along the last dimension that share one scale.
GROUP_SIZE = 128
def reference_quant(x: torch.Tensor, use_ue8m0: bool):
    """
    Quantize ``x`` with the reference Triton kernel
    ``_per_token_group_quant_fp8_colmajor`` from
    vllm.model_executor.layers.quantization.utils.fp8_utils.

    Returns the FP8 tensor and its float32 scales; the scales are allocated
    transposed and then permuted so they come out in column-major layout.
    """
    quantized = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)

    # Allocate the scale tensor in column-major format.
    scale_shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    scales = torch.empty(scale_shape, device=x.device, dtype=torch.float32).permute(
        -1, -2
    )

    # One kernel program per quantization group.
    num_groups = x.numel() // GROUP_SIZE
    block = triton.next_power_of_2(GROUP_SIZE)

    # heuristics for number of warps
    warps = min(max(block // 256, 1), 8)

    fp8_range = torch.finfo(FLOAT8_DTYPE)

    _per_token_group_quant_fp8_colmajor[(num_groups,)](
        x,
        quantized,
        scales,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        scales.stride(1),
        eps=1e-10,
        fp8_min=fp8_range.min,
        fp8_max=fp8_range.max,
        use_ue8m0=use_ue8m0,
        BLOCK=block,
        num_warps=warps,
        num_stages=1,
    )
    return quantized, scales
def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Unfused reference: run ``silu_and_mul`` and then the reference quantizer.

    Args:
        x: 2-D input of shape ``(T, N)``; the activation output is ``(T, N // 2)``.
        use_ue8m0: forwarded unchanged to ``reference_quant``.

    Returns:
        The ``(quantized, scales)`` pair produced by ``reference_quant``.
    """
    T, N = x.size()
    # Allocate the scratch buffer on the input's own device rather than on the
    # current default CUDA device, so both operands of the op are co-located.
    ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device=x.device)
    torch.ops._C.silu_and_mul(ref_act_out, x)
    return reference_quant(ref_act_out, use_ue8m0)
@pytest.mark.parametrize("T", [128, 256, 512])
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
    """Check the fused silu-mul + group-quant op against the unfused reference."""
    current_platform.seed_everything(42)

    activations = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
    ue8m0_enabled = is_deep_gemm_e8m0_used()

    # Fused implementation under test.
    fused_q, fused_scales = silu_mul_per_token_group_quant_fp8_colmajor(
        activations, use_ue8m0=ue8m0_enabled
    )

    # Unfused reference path.
    expected_q, expected_scales = reference(activations, ue8m0_enabled)

    # Both FP8 tensors are upcast to float32 before being compared.
    fused_f32 = fused_q.to(torch.float32)
    expected_f32 = expected_q.to(torch.float32)
    torch.testing.assert_close(fused_f32, expected_f32)
    torch.testing.assert_close(fused_scales, expected_scales)
tests/kernels/quant_utils.py
View file @
8d75f22e
...
...
@@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
.
clamp
(
fp8_traits_min
,
fp8_traits_max
)
.
to
(
FP8_DTYPE
)
)
return
ref_out
,
ref_scale
.
view
(
(
1
,
1
)
)
return
ref_out
,
ref_scale
.
view
(
1
)
def
native_w8a8_block_matmul
(
...
...
tests/kernels/quantization/test_block_fp8.py
View file @
8d75f22e
...
...
@@ -54,6 +54,10 @@ def setup_cuda():
torch
.
set_default_device
(
"cuda"
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_fp8_fnuz
(),
reason
=
"This platform supports e4m3fnuz, not e4m3fn."
,
)
@
pytest
.
mark
.
parametrize
(
"num_tokens,d,dtype,group_size,seed"
,
itertools
.
product
(
NUM_TOKENS
,
D
,
DTYPES
,
GROUP_SIZE
,
SEEDS
),
...
...
@@ -78,14 +82,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
def
test_w8a8_block_fp8_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
torch
.
manual_seed
(
seed
)
factor_for_scale
=
1e-2
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_info
=
torch
.
finfo
(
current_platform
.
fp8_dtype
()
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
A_fp32
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
A_fp8
=
A_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
A_fp8
=
A_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
current_platform
.
fp8_dtype
()
)
B_fp32
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
)
-
0.5
)
*
2
*
fp8_max
B_fp8
=
B_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
B_fp8
=
B_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
current_platform
.
fp8_dtype
()
)
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
...
...
@@ -103,6 +107,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert
rel_diff
<
0.001
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"CUTLASS only supported on CUDA platform."
)
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_cutlass_matmul
():
# Test simple case where weight.shape % 128 != 0,
...
...
@@ -151,6 +158,10 @@ def test_w8a8_block_fp8_cutlass_matmul():
assert
rel_diff
<
0.001
@
pytest
.
mark
.
skipif
(
current_platform
.
is_fp8_fnuz
(),
reason
=
"This platform supports e4m3fnuz, not e4m3fn."
,
)
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
),
...
...
tests/kernels/quantization/test_cutlass_scaled_mm.py
View file @
8d75f22e
...
...
@@ -15,6 +15,9 @@ from vllm import _custom_ops as ops
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
if
not
current_platform
.
is_cuda
():
pytest
.
skip
(
"These tests use CUTLASS which requires CUDA"
,
allow_module_level
=
True
)
MNK_FACTORS
=
[
(
1
,
256
,
128
),
(
1
,
16384
,
1024
),
...
...
tests/kernels/quantization/test_cutlass_w4a8.py
View file @
8d75f22e
...
...
@@ -12,12 +12,18 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
convert_packed_uint4b8_to_signed_int4_inplace
,
pack_cols
,
pack_rows
,
quantize_weights
,
unpack_quantized_values_into_int32
,
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
if
not
current_platform
.
is_cuda
():
pytest
.
skip
(
"These tests use CUTLASS which requires CUDA"
,
allow_module_level
=
True
)
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods
...
...
@@ -167,8 +173,7 @@ def create_test_tensors(
# for the practical use case we need per-tok scales for fp8 activations
w_tok_s
=
torch
.
randn
((
m
,),
device
=
"cuda"
,
dtype
=
types
.
token_scale_type
)
# weights are already per-group quantized, use placeholder here
w_ch_s
=
torch
.
ones
((
n
,),
device
=
"cuda"
,
dtype
=
types
.
channel_scale_type
)
w_ch_s
=
torch
.
randn
((
n
,),
device
=
"cuda"
,
dtype
=
types
.
channel_scale_type
)
return
Tensors
(
w_ref
=
w_ref
,
...
...
@@ -211,7 +216,7 @@ def mm_test_helper(
print
(
output_ref
)
torch
.
testing
.
assert_close
(
output
,
output_ref
.
to
(
output
.
dtype
),
rtol
=
1e-
3
,
atol
=
1e-
3
output
,
output_ref
.
to
(
output
.
dtype
),
rtol
=
1e-
2
,
atol
=
1e-
2
)
...
...
@@ -257,7 +262,7 @@ def test_w4a8_cuda_graph():
)
w_tok_s
=
torch
.
randn
((
m
,),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w_ch_s
=
torch
.
ones
((
n
,),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w_ch_s
=
torch
.
randn
((
n
,),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
# Construct a trivial model with a single layer that calls the kernel
model
=
W4A8Layer
(
...
...
@@ -287,4 +292,38 @@ def test_w4a8_cuda_graph():
output
.
zero_
()
g
.
replay
()
torch
.
testing
.
assert_close
(
output
,
output_ref
,
rtol
=
1e-3
,
atol
=
1e-3
)
torch
.
testing
.
assert_close
(
output
,
output_ref
,
rtol
=
1e-2
,
atol
=
1e-2
)
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES)
def test_convert_packed_uint4b8_to_signed_int4_inplace(shape):
    """
    The W4A16 checkpoints encode the weights as int4b8 packed to int32.
    The CUTLASS kernels expect signed int4 packed to int32.
    This test checks that the runtime int4b8 -> signed int4 conversion
    matches the offline conversion step exactly.
    """
    _, N, K = shape
    int32_range = torch.iinfo(torch.int32)

    # Random weights; eight 4-bit values are packed into every int32.
    packed = torch.randint(
        low=int32_range.min,
        high=int32_range.max + 1,
        size=(N, K // 8),
        dtype=torch.int32,
        device="cuda",
    )

    # Reference: unpack, shift int4b8 -> signed int4, mask to a nibble, repack.
    signed = (
        unpack_quantized_values_into_int32(
            packed.clone(), scalar_types.uint4b8, packed_dim=1
        )
        - 8
    )
    expected = pack_cols(signed & 0x0F, 4, *signed.shape)

    # The converter receives a clone, so `packed` keeps the original encoding.
    converted = convert_packed_uint4b8_to_signed_int4_inplace(packed.clone())

    assert torch.equal(expected, converted)
    assert not torch.equal(expected, packed)
tests/kernels/quantization/test_cutlass_w4a8_moe.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
"""
import
random
from
dataclasses
import
dataclass
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
pack_rows
,
quantize_weights
,
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
# True when the current device reports a major compute capability of 9 or higher.
# NOTE(review): get_device_capability() is presumably Optional (None when no
# capability can be queried) - confirm against the platform API. Guarding here
# keeps module import from raising TypeError so the skipif markers can apply.
_device_capability = current_platform.get_device_capability()
IS_SUPPORTED_BY_GPU = _device_capability is not None and _device_capability[0] >= 9
def
to_fp8
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def cutlass_quantize(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    """
    Quantize weights into W4 and compute reference dequantized weights.

    Encoding/reordering of weights and packing of scales is deferred
    until after all experts are combined.

    Returns ``(w_ref, w_q, w_s, w_zp)``: the reference weights in ``atype``,
    the packed and transposed quantized weights, the scales cast to ``atype``,
    and the zero points from ``quantize_weights``. ``stype`` is accepted but
    not read by this function.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    # The reference returned by quantize_weights is discarded and rebuilt below.
    _, w_q, w_s, w_zp = quantize_weights(
        w, wtype, group_size=group_size, zero_points=zero_points
    )

    # Since scales are later cast to fp8, recompute w_ref in atype here.
    scales_in_atype = w_s.to(atype)
    expanded_scales = scales_in_atype.to(torch.float32).repeat_interleave(
        group_size, dim=0
    )
    w_ref = (w_q.to(torch.float32) * expanded_scales).to(atype)

    # Bit mask prevents sign extension of int4 when packing.
    nibbles = w_q & 0x0F
    packed = pack_rows(nibbles, wtype.size_bits, *w_q.shape)

    # Make weights row-major (N, K).
    w_q = packed.t().contiguous()

    return w_ref, w_q, scales_in_atype, w_zp
def cutlass_preprocess(
    w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor]
):
    """
    Reorder/encode expert weights and pack scales.

    Both per-expert lists are stacked along a new leading dimension before
    being handed to the CUTLASS helper ops.

    Returns:
        w_q_packed: Packed/encoded int4 weights for all experts.
        w_s_packed: Packed fp8 scales for all experts.
        packed_layout: Layout/stride metadata for grouped GEMM.
    """
    stacked_scales = torch.stack(w_s_experts)
    w_s_packed = ops.cutlass_pack_scale_fp8(stacked_scales)

    # expects dim 3
    stacked_weights = torch.stack(w_q_experts)
    w_q_packed, packed_layout = ops.cutlass_encode_and_reorder_int4b_grouped(
        stacked_weights
    )
    return w_q_packed, w_s_packed, packed_layout
# Quantization group size used for the weights and passed to the kernel;
# make_moe_test_setup asserts that K is a multiple of it.
GROUP_SIZE = 128
# Parametrized problem shapes, each tuple is (num_experts, N, K).
TEST_SHAPES = [
    (8, 512, 2048),
    (8, 2048, 2048),
    (64, 512, 1024),
    (64, 2048, 2048),
    (4, 2048, 768),
    (8, 768, 2048),
    (64, 1536, 2048),
    (128, 8192, 4096),  # test overflow int32
]
# Per-expert token counts are generated as multiples of this value.
ALIGNMENT = 16  # torch._scaled_mm alignment for M, needed for reference check
@dataclass
class MoETestSetup:
    """Bundle of inputs, outputs and metadata for one cutlass_w4a8_moe_mm call.

    The per-field notes describe the values as built by make_moe_test_setup.
    """

    # Number of experts.
    num_experts: int
    # Problem dimensions; K is asserted to be a multiple of GROUP_SIZE.
    K: int
    N: int
    # Token count per expert (may be 0) and their sum.
    Ms: list[int]
    M_full: int
    # FP8 activations of shape (M_full, K) and their float32 copy.
    a: torch.Tensor
    a_ref: torch.Tensor
    # int64 tensor of shape (num_experts,), every entry equal to K.
    a_strides: torch.Tensor
    # bfloat16 output buffer of shape (M_full, N).
    out: torch.Tensor
    # int64 tensor of shape (num_experts,), every entry equal to N.
    c_strides: torch.Tensor
    # float32 scales of shape (M_full, 1).
    per_tok_scales: torch.Tensor
    # float32 scales of shape (num_experts, N, 1).
    per_chan_scales: torch.Tensor
    # One reference (dequantized) weight tensor per expert.
    w_refs: list[torch.Tensor]
    # Outputs of cutlass_preprocess: encoded weights and packed scales.
    w_q_packed: torch.Tensor
    w_s_packed: torch.Tensor
    # int32 tensor with one [N, M, K] row per expert.
    problem_sizes: torch.Tensor
    # int64 start row of each expert inside `a` / `out` (exclusive cumsum of Ms).
    expert_offsets: torch.Tensor
    # Layout metadata returned by cutlass_preprocess.
    b_strides: torch.Tensor
    # int64 tensor of shape (num_experts, 2); column 0 is N, column 1 is 0.
    group_scale_strides: torch.Tensor
def make_moe_test_setup(
    num_experts: int,
    K: int,
    N: int,
    *,
    alignment: int = ALIGNMENT,
    max_blocks: int = 64,
    device: str = "cuda",
    random_zero: bool = False,
) -> MoETestSetup:
    """Create a full set of tensors for testing cutlass_w4a8_moe_mm.

    Args:
        num_experts: number of experts to generate.
        K, N: problem dimensions; K must be a multiple of GROUP_SIZE.
        alignment: every per-expert token count is a multiple of this value.
        max_blocks: upper bound on the number of `alignment` blocks per expert.
        device: device on which the tensors are created.
        random_zero: when True (and there is more than one expert), a random
            subset of experts is given zero tokens.
    """
    assert K % GROUP_SIZE == 0

    # Token counts per expert (multiples of `alignment`).
    tokens_per_expert = [
        alignment * random.randint(1, max_blocks) for _ in range(num_experts)
    ]

    # set random experts to 0 tokens
    if random_zero and num_experts > 1:
        empty_count = max(1, num_experts // 8)
        for expert in random.sample(range(num_experts), k=empty_count):
            tokens_per_expert[expert] = 0

    total_tokens = sum(tokens_per_expert)
    assert total_tokens > 0

    # Activations. The random draws below keep the original order so that a
    # fixed seed yields the same data.
    activations = to_fp8(torch.randn((total_tokens, K), device=device))
    activations_f32 = activations.to(torch.float32)

    # Channel/token scales.
    token_scales = torch.randn((total_tokens, 1), dtype=torch.float32, device=device)
    channel_scales = torch.randn(
        (num_experts, N, 1), dtype=torch.float32, device=device
    )

    # Output buffer and the row strides of the activation/output matrices.
    output = torch.empty((total_tokens, N), dtype=torch.bfloat16, device=device)
    activation_strides = torch.full(
        (num_experts,), K, dtype=torch.int64, device=device
    )
    output_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device)

    # Expert weights and scales.
    wtype = scalar_types.int4
    atype = stype = torch.float8_e4m3fn
    reference_weights, quantized_weights, weight_scales = [], [], []
    for _ in range(num_experts):
        dense = to_fp8(torch.randn((K, N), device=device)).to(torch.float16)
        ref_i, q_i, s_i, _zero_points = cutlass_quantize(
            atype, dense, wtype, stype, GROUP_SIZE, zero_points=False
        )
        reference_weights.append(ref_i)
        quantized_weights.append(q_i)
        weight_scales.append(s_i)

    w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(
        quantized_weights, weight_scales
    )

    problem_sizes = torch.tensor(
        [[N, count, K] for count in tokens_per_expert],
        dtype=torch.int32,
        device=device,
    )

    # Exclusive prefix sum of the token counts: start row of each expert.
    expert_offsets = torch.cumsum(
        torch.tensor([0] + tokens_per_expert[:-1], dtype=torch.int64), dim=0
    ).to(device=device)

    # B strides and group scale strides.
    group_scale_strides = torch.zeros(
        (num_experts, 2), dtype=torch.int64, device=device
    )
    group_scale_strides[:, 0] = N

    return MoETestSetup(
        num_experts=num_experts,
        K=K,
        N=N,
        Ms=tokens_per_expert,
        M_full=total_tokens,
        a=activations,
        a_ref=activations_f32,
        a_strides=activation_strides,
        out=output,
        c_strides=output_strides,
        per_tok_scales=token_scales,
        per_chan_scales=channel_scales,
        w_refs=reference_weights,
        w_q_packed=w_q_packed,
        w_s_packed=w_s_packed,
        problem_sizes=problem_sizes,
        expert_offsets=expert_offsets,
        b_strides=packed_layout,
        group_scale_strides=group_scale_strides,
    )
def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor:
    """Compute reference output using torch._scaled_mm per expert.

    Each expert's slice of rows is multiplied by that expert's reference
    weights; experts with an empty slice are skipped.
    """
    reference = torch.empty_like(setup.out)

    row_starts = setup.expert_offsets.cpu().tolist()
    row_ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist()

    fp8 = torch.float8_e4m3fn
    for expert in range(setup.num_experts):
        rows = slice(row_starts[expert], row_ends[expert])
        if rows.start == rows.stop:
            # No tokens were routed to this expert.
            continue

        lhs = setup.a_ref[rows].to(fp8)
        # The t().contiguous().t() round trip keeps the shape and makes the
        # underlying storage column-major.
        rhs = setup.w_refs[expert].to(fp8).t().contiguous().t()
        row_scales = setup.per_tok_scales[rows]  # (M, 1)
        col_scales = setup.per_chan_scales[expert].reshape(1, -1)  # (1, N)

        reference[rows] = torch._scaled_mm(
            lhs,
            rhs,
            row_scales,
            col_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        )

    return reference
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
@pytest.mark.parametrize("shape", TEST_SHAPES)
@pytest.mark.parametrize("random_zero", [True, False])
def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
    """Run the grouped GEMM once and compare it with the per-expert reference."""
    num_experts, N, K = shape
    current_platform.seed_everything(42)

    setup = make_moe_test_setup(
        num_experts=num_experts,
        K=K,
        N=N,
        max_blocks=64,
        random_zero=random_zero,
    )

    # Positional arguments in the order the op is called with.
    kernel_args = (
        setup.out,
        setup.a,
        setup.w_q_packed,
        setup.per_tok_scales,
        setup.per_chan_scales,
        setup.w_s_packed,
        GROUP_SIZE,
        setup.expert_offsets,
        setup.problem_sizes,
        setup.a_strides,
        setup.b_strides,
        setup.c_strides,
        setup.group_scale_strides,
    )
    ops.cutlass_w4a8_moe_mm(*kernel_args)
    torch.cuda.synchronize()

    expected = compute_moe_reference_output(setup)
    torch.testing.assert_close(setup.out, expected, rtol=1e-2, atol=1e-2)
class W4A8MoELayer(torch.nn.Module):
    """
    Minimal wrapper module to test cuda graphs
    """

    def __init__(self, setup: MoETestSetup):
        super().__init__()
        self.setup = setup

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """Run the grouped GEMM on ``a`` and return the setup's output buffer.

        Every argument other than ``a`` is taken from the stored setup; the
        result is written into ``setup.out``, which is also the return value.
        """
        cfg = self.setup
        kernel_args = (
            cfg.out,
            a,
            cfg.w_q_packed,
            cfg.per_tok_scales,
            cfg.per_chan_scales,
            cfg.w_s_packed,
            GROUP_SIZE,
            cfg.expert_offsets,
            cfg.problem_sizes,
            cfg.a_strides,
            cfg.b_strides,
            cfg.c_strides,
            cfg.group_scale_strides,
        )
        ops.cutlass_w4a8_moe_mm(*kernel_args)
        return cfg.out
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
def test_cutlass_w4a8_moe_mm_cuda_graph():
    """Capture the grouped GEMM in a CUDA graph and check the replayed output."""
    current_platform.seed_everything(42)

    # Fixed config for CUDA graph test (single parameter point).
    setup = make_moe_test_setup(num_experts=8, K=512, N=2048, max_blocks=32)

    # Construct model that calls the grouped GEMM kernel.
    layer = W4A8MoELayer(setup)

    # Build reference output once.
    expected = compute_moe_reference_output(setup)

    # static input tensor for graph replay
    static_input = setup.a.clone()

    # Capture and run the model in a CUDA graph.
    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_output = layer(static_input)

    # Clear the buffer so that a match can only come from the replay.
    static_output.zero_()
    graph.replay()

    torch.testing.assert_close(static_output, expected, rtol=1e-2, atol=1e-2)
tests/kernels/quantization/test_hadacore.py
View file @
8d75f22e
...
...
@@ -8,6 +8,13 @@ import torch
from
compressed_tensors.transform
import
deterministic_hadamard_matrix
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
if
current_platform
.
is_rocm
():
pytest
.
skip
(
"These tests require hadacore_transform, not supported on ROCm."
,
allow_module_level
=
True
,
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
...
...
tests/kernels/quantization/test_machete_mm.py
View file @
8d75f22e
...
...
@@ -23,6 +23,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
if
current_platform
.
is_rocm
():
pytest
.
skip
(
"These tests require machete_prepack_B, not supported on ROCm."
,
allow_module_level
=
True
,
)
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
...
...
tests/kernels/quantization/test_marlin_gemm.py
View file @
8d75f22e
...
...
@@ -56,6 +56,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
if
current_platform
.
is_rocm
():
pytest
.
skip
(
"These tests require gptq_marlin_repack,"
"marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
"or gptq_marlin_gemm which are not supported on ROCm."
,
allow_module_level
=
True
,
)
ACT_ORDER_OPTS
=
[
False
,
True
]
K_FULL_OPTS
=
[
False
,
True
]
USE_ATOMIC_ADD_OPTS
=
[
False
,
True
]
...
...
tests/kernels/test_top_k_per_row.py
View file @
8d75f22e
...
...
@@ -9,23 +9,45 @@ from vllm.platforms import current_platform
# Test parameters
NUM_ROWS
=
[
1
,
32
,
2050
]
TOP_K_VALUES
=
[
2048
]
BATCH_SIZE
=
[
1
,
2
,
4
,
2048
,
4096
]
NEXT_N
=
[
1
,
2
,
4
,
8
]
TOP_K_VALUES
=
[
2048
,
3000
]
BATCH_SIZE
=
[
1
,
2
,
2048
]
NEXT_N
=
[
1
,
8
]
DATA_GENERATION
=
[
"random"
,
"10LSBits"
]
def
create_random_logits
(
row_starts
:
torch
.
Tensor
,
row_ends
:
torch
.
Tensor
,
vocab_size
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
data_generation
:
str
,
)
->
torch
.
Tensor
:
"""Create random logits tensor for testing."""
torch
.
manual_seed
(
seed
)
np
.
random
.
seed
(
seed
)
# Generate logits with some structure to make testing more meaningful
logits
=
torch
.
randn
(
row_starts
.
shape
[
0
],
max
(
row_ends
),
dtype
=
dtype
,
device
=
"cuda"
)
if
data_generation
==
"random"
:
logits
=
torch
.
randn
(
row_starts
.
shape
[
0
],
max
(
row_ends
),
dtype
=
dtype
,
device
=
"cuda"
)
elif
data_generation
==
"10LSBits"
:
top_22_bits_mask
=
0xFFFFFC00
last_10_bits_mask
=
0x000003FF
fixed_top_22_bits
=
0x3F900000
# Generate random bits for the last 10 bits
random_bottom_bits
=
torch
.
randint
(
0
,
2
**
10
,
(
row_starts
.
shape
[
0
],
max
(
row_ends
)),
dtype
=
torch
.
int32
,
device
=
"cuda"
,
)
# Combine: fixed top 22 bits with random last 10 bits
logits_bits
=
(
fixed_top_22_bits
&
top_22_bits_mask
)
|
(
random_bottom_bits
&
last_10_bits_mask
)
logits
=
logits_bits
.
view
(
dtype
)
for
i
,
end
in
enumerate
(
row_ends
):
logits
[
i
,
end
:]
=
float
(
"-inf"
)
return
logits
...
...
@@ -113,13 +135,13 @@ def test_top_k_per_row(
# Create test data
vocab_size
=
20000
row_starts
,
row_ends
=
create_row_boundaries
(
num_rows
,
vocab_size
)
logits
=
create_random_logits
(
row_starts
,
row_ends
,
vocab_size
,
torch
.
float32
,
42
)
logits
=
create_random_logits
(
row_starts
,
row_ends
,
torch
.
float32
,
42
,
"random"
)
# Create output tensors
indices
=
torch
.
empty
((
num_rows
,
top_k
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# Run CUDA implementation
torch
.
ops
.
_C
.
top_k_per_row
(
torch
.
ops
.
_C
.
top_k_per_row
_prefill
(
logits
,
row_starts
,
row_ends
,
...
...
@@ -127,6 +149,7 @@ def test_top_k_per_row(
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
top_k
,
)
# Run reference implementation
...
...
@@ -139,27 +162,23 @@ def test_top_k_per_row(
# Compare results
assert
compare_top_k_results
(
logits
,
indices
,
torch_indices
,
row_starts
,
row_ends
,
top_k
),
"CUDA top_k_per_row results don't match torch.topk"
),
"CUDA top_k_per_row
_prefill
results don't match torch.topk"
@
pytest
.
mark
.
parametrize
(
"top_k"
,
TOP_K_VALUES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZE
)
@
pytest
.
mark
.
parametrize
(
"next_n"
,
NEXT_N
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test requires CUDA"
)
@
torch
.
inference_mode
()
def
test_top_k_per_row_decode
(
def
_run_top_k_per_row_decode_test
(
top_k
:
int
,
batch_size
:
int
,
next_n
:
int
,
vocab_size
:
int
,
data_generation
:
str
,
)
->
None
:
"""
Test top_k_per_row with seq_lens tensor
.
Helper function to run top_k_per_row_decode test with given parameters
.
"""
torch
.
set_default_device
(
"cuda:0"
)
# Create test data
num_rows
=
batch_size
*
next_n
vocab_size
=
20000
seq_lens
=
torch
.
randint
(
vocab_size
,
(
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
...
...
@@ -167,7 +186,9 @@ def test_top_k_per_row_decode(
row_indices
=
torch
.
arange
(
num_rows
,
device
=
"cuda"
)
//
next_n
next_n_offset
=
torch
.
arange
(
num_rows
,
device
=
"cuda"
)
%
next_n
row_ends
=
seq_lens
[
row_indices
]
-
next_n
+
next_n_offset
+
1
logits
=
create_random_logits
(
row_starts
,
row_ends
,
vocab_size
,
torch
.
float32
,
42
)
logits
=
create_random_logits
(
row_starts
,
row_ends
,
torch
.
float32
,
42
,
data_generation
)
# Create output tensors
indices
=
torch
.
empty
((
num_rows
,
top_k
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
...
...
@@ -181,6 +202,7 @@ def test_top_k_per_row_decode(
num_rows
,
logits
.
stride
(
0
),
logits
.
stride
(
1
),
top_k
,
)
torch
.
cuda
.
synchronize
()
...
...
@@ -195,4 +217,41 @@ def test_top_k_per_row_decode(
# Compare results
assert
compare_top_k_results
(
logits
,
indices
,
torch_indices
,
row_starts
,
row_ends
,
top_k
),
"CUDA top_k_per_row results don't match torch.topk"
),
"CUDA top_k_per_row_decode results don't match torch.topk"
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("next_n", NEXT_N)
@pytest.mark.parametrize("data_generation", DATA_GENERATION)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode(
    top_k: int,
    batch_size: int,
    next_n: int,
    data_generation: str,
) -> None:
    """
    Test top_k_per_row with seq_lens tensor.

    Delegates to the shared decode helper with a fixed vocabulary size of 20000.
    """
    _run_top_k_per_row_decode_test(
        top_k=top_k,
        batch_size=batch_size,
        next_n=next_n,
        vocab_size=20000,
        data_generation=data_generation,
    )
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@torch.inference_mode()
def test_top_k_per_row_decode_large_vocab_size() -> None:
    """
    Test top_k_per_row_decode with large vocabulary size.

    Runs a single fixed configuration with a vocabulary of 300000 entries.
    """
    _run_top_k_per_row_decode_test(
        top_k=2048,
        batch_size=2,
        next_n=2,
        vocab_size=300000,
        data_generation="random",
    )
tests/kv_transfer/test_lookup_buffer.py
deleted
100644 → 0
View file @
ce888aa4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
random
import
torch
from
tqdm
import
tqdm
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer
import
SimpleBuffer
from
vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe
import
PyNcclPipe
# TODO: the test depends on a lot of fields in the current implementation.
# We should have standard interface instead direct field access
def test_run(my_rank, buffer, device):
    """Single insert / drop_select round trip between rank 0 and rank 1.

    Rank 0 inserts one (tokens, roi, key, value, placeholder) entry; after a
    barrier rank 1 retrieves it with ``drop_select`` and checks the contents;
    after a second barrier rank 0 checks that its buffer is empty again.
    """
    # buffer should be empty in the beginning
    if my_rank == 0:
        assert buffer.buffer_size == 0
        assert len(buffer.buffer) == 0

    print(f"My rank: {my_rank}, device: {device}")

    # insert
    tokens = torch.tensor([1, 2, 3]).to(device)
    # All tokens are positive, so the region-of-interest mask is all True.
    roi = tokens > 0
    if my_rank == 0:
        key = 2.0 * torch.ones([5, 6]).to(device)
        value = 3.0 * torch.ones([5, 6]).to(device)

        placeholder = torch.tensor([1]).to(device)

        buffer.insert(tokens, roi, key, value, placeholder)

    # Both ranks meet here before rank 1 queries the buffer.
    torch.distributed.barrier()

    # drop_select
    if my_rank == 1:
        tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi)
        assert torch.allclose(tokens, tok)
        assert torch.allclose(roi, roi_)
        assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device))
        assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device))
    torch.distributed.barrier()

    # After the select, rank 0 expects its buffer to be empty again.
    if my_rank == 0:
        assert buffer.buffer_size == 0
        assert len(buffer.buffer) == 0

    print(f"My rank: {my_rank}, Test run passed!")
def stress_test(my_rank, buf, device):
    """Insert 200 shuffled requests on rank 0 and select them on the other rank.

    Both ranks build the same 200 requests (same torch seed) but shuffle them
    with a rank-specific ``random`` seed. Rank 0 inserts every request; the
    other rank calls ``drop_select`` for every request and counts how many
    calls return ``None``. At the end that count is sent to rank 0, which
    checks it against what is left in its buffer.
    """
    torch.distributed.barrier()
    torch.manual_seed(100)

    reqs = [
        (
            torch.rand(100).to(device),  # tokens
            torch.ones(100).bool().to(device),  # roi
            torch.rand(100).to(device),  # key
            torch.rand(100).to(device),  # value
            torch.rand(100).to(device),  # hidden
        )
        for _ in tqdm(range(200))
    ]

    # Different shuffle order per rank, so requests do not arrive in lockstep.
    random.seed(my_rank)
    random.shuffle(reqs)

    torch.distributed.barrier()

    # Number of drop_select calls that returned nothing (non-zero ranks only).
    n = 0

    # the buffer size can only store 100 reqs
    # so the sender will occasionally block to wait for the receiver.
    for req in tqdm(reqs):
        if my_rank == 0:
            buf.insert(*req)
        else:
            tok, roi, k, v, h = req
            tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi)

            if tok_ is None:
                # A miss must be reported for every returned field.
                assert roi_ is None
                assert k_ is None
                assert v_ is None
                assert h_ is None
                n += 1
            else:
                assert torch.allclose(tok, tok_)
                assert torch.allclose(roi, roi_)
                assert torch.allclose(k, k_)
                assert torch.allclose(v, v_)
                assert torch.allclose(h, h_)
    print(f"Rank {my_rank} done")
    torch.distributed.barrier()

    if my_rank == 0:
        x = torch.tensor([0])
        torch.distributed.recv(x, 1)
        # the # of None received is the kv that are not selected
        assert x.item() == len(buf.buffer)
        # and the size of the buffer should be 1700 * buffer len
        # NOTE(review): 1700 presumably is the byte size of one request
        # (four 100-element float tensors plus a 100-element bool roi) - confirm.
        print(buf.buffer_size)
        assert buf.buffer_size == 1700 * len(buf.buffer)
    else:
        torch.distributed.send(torch.tensor([n]), 0)

    print(f"My rank: {my_rank}, Passed stress test!")
# Script entry point: run once per rank with the RANK environment variable set.
if __name__ == "__main__":
    my_rank = int(os.environ["RANK"])

    # Two-process gloo group used for the barriers and send/recv in the tests.
    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:12398",
        world_size=2,
        rank=my_rank,
    )

    print(f"initialized! My rank is {my_rank}")

    config = KVTransferConfig(
        kv_connector="P2pNcclConnector",
        kv_buffer_device="cuda",
        kv_buffer_size=1e9,
        kv_rank=my_rank,
        kv_role="kv_both",  # this arg doesn't matter in this test
        kv_parallel_size=2,
        kv_ip="127.0.0.1",
        kv_port=12345,
    )

    # Two pipes built from the same config: one on "cuda" and one on "cpu",
    # with different port offsets.
    data_pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
        device="cuda",
        port_offset=0,
    )
    cpu_pipe = PyNcclPipe(
        local_rank=my_rank,
        config=config,
        device="cpu",
        port_offset=1,
    )

    buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000)

    test_run(my_rank, buffer, data_pipe.device)

    stress_test(my_rank, buffer, data_pipe.device)

    # Close the buffer first, then both pipes.
    buffer.close()
    data_pipe.close()
    cpu_pipe.close()
    print("Done")
tests/kv_transfer/test_lookup_buffer.sh
deleted
100644 → 0
View file @
ce888aa4
#!/bin/bash
# Launch both ranks of test_lookup_buffer.py in the background and wait for them.
RANK=0 python3 test_lookup_buffer.py &
PID0=$!
RANK=1 python3 test_lookup_buffer.py &
PID1=$!

# Collect both exit statuses so that a failure on rank 0 is not masked by a
# successful rank 1 (the script's status used to be that of the last `wait`).
STATUS=0
wait "$PID0" || STATUS=$?
wait "$PID1" || STATUS=$?
exit "$STATUS"
Prev
1
…
8
9
10
11
12
13
14
15
16
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment