Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45a060d6
Commit
45a060d6
authored
Feb 05, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.15.1' into v0.15.1-dev
parents
99fc9fc3
1892993b
Changes
64
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
663 additions
and
105 deletions
+663
-105
tests/kernels/quantization/untest_rocm_skinny_gemms.py
tests/kernels/quantization/untest_rocm_skinny_gemms.py
+15
-2
tests/models/registry.py
tests/models/registry.py
+9
-0
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+343
-40
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+60
-0
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+0
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-16
vllm/compilation/backends.py
vllm/compilation/backends.py
+8
-1
vllm/config/compilation.py
vllm/config/compilation.py
+31
-0
vllm/config/speculative.py
vllm/config/speculative.py
+6
-0
vllm/forward_context.py
vllm/forward_context.py
+20
-12
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+90
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+6
-1
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
...model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
+2
-1
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
...model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+14
-13
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
.../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+39
-8
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+1
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+3
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-4
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+4
-4
No files found.
tests/kernels/quantization/untest_rocm_skinny_gemms.py
View file @
45a060d6
...
...
@@ -87,6 +87,13 @@ NKM_FACTORS_WVSPLITK_FP8 = [
SEEDS
=
[
0
]
def pad_weights_fp8(weight):
    """Return ``weight`` as a view into a last-dimension-padded buffer.

    The tensor is zero-padded at the end of its last dimension by
    ``256 // weight.element_size()`` elements and the padding is then sliced
    off again. The result has the same shape and values as ``weight`` but
    keeps the row stride of the padded buffer.
    """
    import torch.nn.functional as F

    # Number of elements that occupy 256 bytes for this dtype.
    extra_cols = 256 // weight.element_size()
    padded = F.pad(weight, (0, extra_cols), "constant", 0)
    # Drop the padding columns from the view only; the storage stays padded.
    return padded[..., :-extra_cols]
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITKRC
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
...
...
@@ -191,11 +198,12 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITK_FP8
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"padded"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_rocm
()
and
current_platform
.
supports_fp8
()),
reason
=
"only test for rocm fp8"
,
)
def
test_rocm_wvsplitk_fp8_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
def
test_rocm_wvsplitk_fp8_kernel
(
n
,
k
,
m
,
dtype
,
seed
,
padded
):
torch
.
manual_seed
(
seed
)
A
=
torch
.
rand
(
n
,
k
,
device
=
"cuda"
)
-
0.5
...
...
@@ -203,6 +211,8 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
A
,
scale_a
=
ref_dynamic_per_tensor_fp8_quant
(
A
)
B
,
scale_b
=
ref_dynamic_per_tensor_fp8_quant
(
B
)
if
padded
:
B
=
pad_weights_fp8
(
B
)
ref_out
=
torch
.
_scaled_mm
(
A
,
B
.
t
(),
out_dtype
=
dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
...
...
@@ -222,11 +232,12 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
@
pytest
.
mark
.
parametrize
(
"n,k,m"
,
NKM_FACTORS_WVSPLITK_FP8
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"padded"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_rocm
()
and
current_platform
.
supports_fp8
()),
reason
=
"only test for rocm fp8"
,
)
def
test_rocm_wvsplitk_fp8_bias1D_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
def
test_rocm_wvsplitk_fp8_bias1D_kernel
(
n
,
k
,
m
,
dtype
,
seed
,
padded
):
torch
.
manual_seed
(
seed
)
xavier
=
math
.
sqrt
(
2
/
k
)
# normalize to avoid large output-bias deltas
...
...
@@ -236,6 +247,8 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
A
,
scale_a
=
ref_dynamic_per_tensor_fp8_quant
(
A
)
B
,
scale_b
=
ref_dynamic_per_tensor_fp8_quant
(
B
)
if
padded
:
B
=
pad_weights_fp8
(
B
)
ref_out
=
torch
.
_scaled_mm
(
A
,
B
.
t
(),
out_dtype
=
dtype
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
bias
=
BIAS
...
...
tests/models/registry.py
View file @
45a060d6
...
...
@@ -487,6 +487,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Step1ForCausalLM"
:
_HfExamplesInfo
(
"stepfun-ai/Step-Audio-EditX"
,
trust_remote_code
=
True
),
"Step3p5ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stepfun-ai/step-3.5-flash"
),
is_available_online
=
False
),
"SmolLM3ForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"HuggingFaceTB/SmolLM3-3B"
)),
"StableLMEpochForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stabilityai/stablelm-zephyr-3b"
)),
"StableLmForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"stabilityai/stablelm-3b-4e1t"
)),
...
...
@@ -1099,6 +1102,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Qwen3NextMTP"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-Next-80B-A3B-Instruct"
),
min_transformers_version
=
"4.56.3"
),
"Step3p5MTP"
:
_HfExamplesInfo
(
"stepfun-ai/Step-3.5-Flash"
,
trust_remote_code
=
True
,
speculative_model
=
"stepfun-ai/Step-3.5-Flash"
,
is_available_online
=
False
,
),
}
_TRANSFORMERS_BACKEND_MODELS
=
{
...
...
tests/v1/core/test_prefix_caching.py
View file @
45a060d6
...
...
@@ -107,7 +107,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
def
make_kv_cache_config_hybrid_model
(
block_size
:
int
,
num_blocks
:
int
,
second_spec_type
:
str
=
"sliding_window"
block_size
:
int
,
num_blocks
:
int
,
sliding_window_blocks
:
int
,
second_spec_type
:
str
=
"sliding_window"
,
)
->
KVCacheConfig
:
if
second_spec_type
==
"sliding_window"
:
second_spec
=
SlidingWindowSpec
(
...
...
@@ -115,7 +118,7 @@ def make_kv_cache_config_hybrid_model(
num_kv_heads
=
1
,
head_size
=
1
,
dtype
=
torch
.
float32
,
sliding_window
=
2
*
block_size
,
sliding_window
=
sliding_window_blocks
*
block_size
,
)
elif
second_spec_type
==
"mamba"
:
second_spec
=
MambaSpec
(
...
...
@@ -325,7 +328,7 @@ def test_prefill(hash_fn):
def
test_prefill_hybrid_model
():
block_size
=
16
manager
=
KVCacheManager
(
make_kv_cache_config_hybrid_model
(
block_size
,
21
),
make_kv_cache_config_hybrid_model
(
block_size
,
21
,
2
),
max_model_len
=
8192
,
enable_caching
=
True
,
hash_block_size
=
block_size
,
...
...
@@ -334,7 +337,8 @@ def test_prefill_hybrid_model():
hash_fn
=
sha256
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
num_full_blocks
=
3
common_token_ids
=
[
i
for
i
in
range
(
num_full_blocks
)
for
_
in
range
(
block_size
)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
...
...
@@ -375,6 +379,7 @@ def test_prefill_hybrid_model():
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
all_token_ids
=
common_token_ids
+
unique_token_ids
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
block_hashes
)
==
3
...
...
@@ -394,34 +399,13 @@ def test_prefill_hybrid_model():
manager
.
free
(
req0
)
manager
.
free
(
req1
)
cached_block_hash_to_block_bak
=
copy
.
copy
(
manager
.
block_pool
.
cached_block_hash_to_block
.
_cache
)
def
test_partial_request_hit
(
request_id
:
str
,
hash_to_evict
:
list
[
BlockHashWithGroupId
],
expect_hit_length
:
int
,
):
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
,
block_size
,
sha256
)
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
_cache
.
pop
(
hash_with_group_id
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
len
(
req
.
block_hashes
)
==
3
assert
num_computed_tokens
==
expect_hit_length
*
block_size
for
block_per_group
in
computed_blocks
.
blocks
:
assert
len
(
block_per_group
)
==
num_computed_tokens
//
block_size
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
_cache
[
hash_with_group_id
]
=
(
cached_block_hash_to_block_bak
[
hash_with_group_id
]
)
manager
.
free
(
req
)
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit
(
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"2"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
1
),
make_block_hash_with_group_id
(
block_hashes
[
0
],
2
),
...
...
@@ -430,13 +414,23 @@ def test_prefill_hybrid_model():
)
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit
(
"3"
,
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
0
)],
0
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"3"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
0
)],
0
,
)
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit
(
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"4"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
),
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
),
...
...
@@ -446,18 +440,36 @@ def test_prefill_hybrid_model():
)
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit
(
"5"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
)],
2
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"5"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
)],
2
,
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"6"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
)],
2
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"6"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
)],
2
,
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"7"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
2
)],
2
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"7"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
2
)],
2
,
)
# Evict different set of blocks for full attention and sliding window makes
...
...
@@ -466,8 +478,12 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers
# have different hit length.
test_partial_request_hit
(
_test_partial_request_hit
(
manager
,
block_size
,
num_full_blocks
,
"8"
,
all_token_ids
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
),
make_block_hash_with_group_id
(
block_hashes
[
0
],
1
),
...
...
@@ -477,6 +493,214 @@ def test_prefill_hybrid_model():
)
def test_prefill_hybrid_model_eagle():
    """Prefix-caching prefill test for a hybrid KV cache with EAGLE enabled.

    Steps:
      1. A first request (6 full blocks + 7 tokens) misses the cache entirely
         and gets 7 blocks in each of the three KV cache groups.
      2. A second request sharing the 6-block prefix is expected to hit
         4 blocks per group.
      3. After freeing both requests, a series of targeted evictions checks
         the resulting hit length via ``_test_partial_request_hit``.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config_hybrid_model(block_size, 31, 3),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )
    hash_fn = sha256

    # Complete 6 blocks (96 tokens): token value k fills block k.
    num_full_blocks = 6
    common_token_ids = []
    for token in range(num_full_blocks):
        common_token_ids.extend([token] * block_size)

    # --- Request 0: full cache miss, plus 1 incomplete block (7 tokens) ---
    unique_token_ids = [6] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == len(all_token_ids) // block_size
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0

    blocks = manager.allocate_slots(
        req0, len(all_token_ids), num_computed_tokens, computed_blocks
    )
    # One row of block ids per KV cache group.
    block_ids = (
        [1, 2, 3, 4, 5, 6, 7],
        [8, 9, 10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19, 20, 21],
    )
    assert blocks is not None
    assert blocks.get_block_ids() == block_ids

    # Check full block metadata: walk the full blocks position by position,
    # looking at the block each group holds for that position.
    parent_block_hash = None
    full_columns = zip(*(group_ids[:-1] for group_ids in block_ids))
    for block_idx, ids_across_groups in enumerate(full_columns):
        start = block_idx * block_size
        block_tokens = tuple(all_token_ids[start : start + block_size])
        block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
        for group_id, block_id in enumerate(ids_across_groups):
            blk_hash = manager.block_pool.blocks[block_id].block_hash
            assert blk_hash is not None
            assert get_block_hash(blk_hash) == block_hash
            assert get_group_id(blk_hash) == group_id
            assert manager.block_pool.blocks[block_id].ref_cnt == 1
        # The hash chain continues from this block to the next one.
        parent_block_hash = block_hash

    # Check partial block metadata: the trailing block of each group is
    # allocated but carries no hash.
    for partial_block_id in (group_ids[-1] for group_ids in block_ids):
        assert manager.block_pool.blocks[partial_block_id].block_hash is None
        assert manager.block_pool.blocks[partial_block_id].ref_cnt == 1

    # --- Request 1: cache hit in the common prefix, 1 incomplete block
    # (5 tokens) ---
    unique_token_ids = [6] * 5
    all_token_ids = common_token_ids + unique_token_ids
    req1 = make_request("1", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == num_full_blocks
    assert computed_blocks.get_block_ids() == (
        [1, 2, 3, 4],
        [0, 9, 10, 11],
        [0, 16, 17, 18],
    )
    assert num_computed_tokens == 4 * block_size

    num_new_tokens = len(all_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(
        req1, num_new_tokens, num_computed_tokens, computed_blocks
    )
    assert blocks is not None
    assert blocks.get_block_ids() == (
        [22, 23, 24],
        [25, 26, 27],
        [28, 29, 30],
    )
    # Every real (non-null) hit block is now shared by req0 and req1.
    for block_per_group in computed_blocks.blocks:
        for block in block_per_group:
            if block != manager.block_pool.null_block:
                assert block.ref_cnt == 2

    block_hashes = req1.block_hashes
    manager.free(req0)
    manager.free(req1)

    # --- Targeted evictions ---
    # Each case: (request id, ((block_hashes index, group id), ...),
    #             expected hit length in blocks).
    eviction_cases = (
        # Evict the blocks outside sliding window, does not affect the hit
        # length.
        ("2", ((0, 1), (0, 2)), 4),
        # Evict the first block of full attention, makes total cache miss.
        ("3", ((0, 0),), 0),
        # Evict the last block of all layers, reduces the hit length to 3.
        ("4", ((-1, 0), (-1, 1), (-1, 2)), 3),
        # Evict the last block of full attention, reduces the hit length to 3.
        ("5", ((-1, 0),), 3),
        # Since the last block of full attention is dropped for eagle, evict
        # the second last block of sliding window (group 1), reduces the hit
        # length to 3.
        ("6", ((-2, 1),), 3),
        # Same as above for the other sliding window group (group 2).
        ("7", ((-2, 2),), 3),
        # Evict different set of blocks for full attention and sliding window
        # makes total cache miss.
        # The cache hit length of full attention is 4 * block_size.
        # The cache hit length of sliding window is 3 * block_size.
        # Then it is cache miss as the two type of layers
        # have different hit length.
        ("8", ((-1, 0), (0, 1), (0, 2)), 0),
    )
    for request_id, evictions, expect_hit_length in eviction_cases:
        hash_to_evict = [
            make_block_hash_with_group_id(block_hashes[hash_idx], group_id)
            for hash_idx, group_id in evictions
        ]
        _test_partial_request_hit(
            manager,
            block_size,
            num_full_blocks,
            request_id,
            all_token_ids,
            hash_to_evict,
            expect_hit_length,
        )
def _test_partial_request_hit(
    manager: KVCacheManager,
    block_size: int,
    num_full_blocks,
    request_id: str,
    prompt_token_ids: list[int],
    hash_to_evict: list[BlockHashWithGroupId],
    expect_hit_length: int,
):
    """Check the prefix-cache hit length after temporarily evicting entries.

    The entries in ``hash_to_evict`` are removed from the manager's
    hash-to-block cache, a fresh request for ``prompt_token_ids`` is looked up,
    and the number of computed tokens is asserted to equal
    ``expect_hit_length * block_size`` with every group reporting the same
    number of hit blocks. The removed entries are then put back and the
    request is freed.

    Args:
        manager: KV cache manager under test.
        block_size: Tokens per block.
        num_full_blocks: Expected number of block hashes on the request.
        request_id: Id for the temporary request.
        prompt_token_ids: Prompt of the temporary request.
        hash_to_evict: Cache keys to remove for the duration of the lookup.
        expect_hit_length: Expected hit length, in blocks.
    """

    def live_cache():
        # Re-resolve the dict on every use, exactly as a direct attribute
        # chain would.
        return manager.block_pool.cached_block_hash_to_block._cache

    # Shallow snapshot so evicted entries can be restored afterwards.
    snapshot = copy.copy(live_cache())
    req = make_request(request_id, prompt_token_ids, block_size, sha256)

    for evicted in hash_to_evict:
        live_cache().pop(evicted)

    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    assert len(req.block_hashes) == num_full_blocks
    assert num_computed_tokens == expect_hit_length * block_size
    for block_per_group in computed_blocks.blocks:
        assert len(block_per_group) == num_computed_tokens // block_size

    # Undo the evictions so later checks see the original cache contents.
    for evicted in hash_to_evict:
        live_cache()[evicted] = snapshot[evicted]
    manager.free(req)
def
_make_hybrid_kv_cache_config
(
block_size
:
int
,
num_blocks
:
int
,
spec_types
:
list
[
str
]
)
->
KVCacheConfig
:
...
...
@@ -655,6 +879,85 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
manager
.
free
(
req1
)
# Test cases with eagle enabled: Only test a single simple case for now.
# Each entry is (spec_types, expected hit length in blocks).
_EAGLE_HYBRID_MODEL_TEST_CASES = [
    # 2 groups: 1 full + 1 other
    pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"),
]


@pytest.mark.parametrize("spec_types,expect_hit_length", _EAGLE_HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations_eagle(
    spec_types: list[str], expect_hit_length: int
):
    """
    Test prefix caching with hybrid models (1 full attn + 1 other) with EAGLE.
    More complex hybrid models with EAGLE are not yet supported (see issue #32802).
    """
    block_size = 16
    num_groups = len(spec_types)

    # Allocate enough blocks for all groups
    manager = KVCacheManager(
        _make_hybrid_kv_cache_config(block_size, 10 * num_groups, spec_types),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )
    hash_fn = sha256

    # Complete 4 blocks (64 tokens): token value k fills block k.
    num_full_blocks = 4
    common_token_ids = []
    for token in range(num_full_blocks):
        common_token_ids.extend([token] * block_size)
    # Plus an incomplete block of 7 tokens.
    all_token_ids = common_token_ids + [4] * 7

    # First request: no cache hit initially
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == num_full_blocks
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0

    blocks = manager.allocate_slots(
        req0, len(all_token_ids), num_computed_tokens, computed_blocks
    )
    assert blocks is not None
    # Should have blocks for all groups
    assert len(blocks.get_block_ids()) == num_groups

    # Second request: same common prefix, different 5-token tail.
    all_token_ids = common_token_ids + [6] * 5
    req1 = make_request("1", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)

    # Should hit cached blocks for all groups
    assert num_computed_tokens == expect_hit_length * block_size
    assert len(computed_blocks.blocks) == num_groups
    # Verify each group has the correct number of computed blocks
    for block_per_group in computed_blocks.blocks:
        assert len(block_per_group) == expect_hit_length

    # Allocate and verify blocks for second request
    num_new_tokens = len(all_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(
        req1, num_new_tokens, num_computed_tokens, computed_blocks
    )
    assert blocks is not None
    assert len(blocks.get_block_ids()) == num_groups

    manager.free(req0)
    manager.free(req1)
def
test_prefill_plp
():
"""Test prefill with APC and some prompt logprobs (plp) requests.
...
...
tests/v1/core/test_scheduler.py
View file @
45a060d6
...
...
@@ -873,6 +873,66 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
assert
stats
.
num_accepted_tokens_per_pos
==
expected
[
3
]
def test_spec_decoding_stats_empty_output():
    """Test that spec decoding stats handle empty output tokens gracefully.

    This is a regression test for a bug where empty sampled_token_ids
    would cause num_accepted = len([]) - 1 = -1, leading to a
    ValueError when incrementing a Prometheus counter with a negative value.
    """
    num_spec_tokens = 3
    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
    request = create_requests(num_requests=1, num_tokens=1)[0]
    req_id = request.request_id
    scheduler.add_request(request)

    def runner_output(sampled_token_ids):
        # Single-request model output that differs only in the sampled tokens.
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index={req_id: 0},
            sampled_token_ids=sampled_token_ids,
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )

    # Initial schedule (prefill)
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1

    # Complete the prefill with a sampled token
    scheduler.update_from_output(output, runner_output([[0]]))

    # Add draft tokens for speculation
    scheduler.update_draft_token_ids(DraftTokenIds([req_id], [[1, 2, 3]]))

    # Schedule the speculated tokens for validation
    output = scheduler.schedule()
    assert req_id in output.scheduled_spec_decode_tokens
    assert len(output.scheduled_spec_decode_tokens[req_id]) == 3

    # Simulate empty output tokens (e.g., due to request abortion or error)
    # This would previously cause num_accepted = -1 and crash
    empty_output = runner_output([[]])

    # This should not raise an error
    engine_core_outputs = scheduler.update_from_output(output, empty_output)

    # Spec decoding stats should be None since no tokens were generated
    if engine_core_outputs:
        scheduler_stats = engine_core_outputs[0].scheduler_stats
    else:
        scheduler_stats = None
    assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None
def
_assert_right_scheduler_output
(
output
:
SchedulerOutput
,
num_requests
:
int
,
...
...
tests/weight_loading/models.txt
View file @
45a060d6
...
...
@@ -19,7 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
...
...
vllm/_custom_ops.py
View file @
45a060d6
...
...
@@ -1117,6 +1117,8 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
def
cutlass_group_gemm_supported
(
cuda_device_capability
:
int
)
->
bool
:
if
cuda_device_capability
<
90
or
cuda_device_capability
>=
110
:
return
False
try
:
return
torch
.
ops
.
_C
.
cutlass_group_gemm_supported
(
cuda_device_capability
)
except
AttributeError
:
...
...
@@ -2249,35 +2251,20 @@ def selective_scan_fwd(
)
# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu)
# are unable to properly handle non-contiguous
# tensors. It might be a good TODO(rasmith) to augment these kernels
# to be able to handle non-contiguous kernels for better performance.
def
rocm_enforce_contiguous_skinny_gemm_inputs
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
a
=
a
.
contiguous
()
# no-op if already contiguous, else clone
b
=
b
.
contiguous
()
# no-op if already contiguous, else clone
return
a
,
b
# ROCm skinny gemms
def
LLMM1
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
rows_per_block
:
int
)
->
torch
.
Tensor
:
a
,
b
=
rocm_enforce_contiguous_skinny_gemm_inputs
(
a
,
b
)
return
torch
.
ops
.
_rocm_C
.
LLMM1
(
a
,
b
,
rows_per_block
)
def
wvSplitK
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
cu_count
:
int
,
bias
:
torch
.
Tensor
=
None
)
->
torch
.
Tensor
:
a
,
b
=
rocm_enforce_contiguous_skinny_gemm_inputs
(
a
,
b
)
return
torch
.
ops
.
_rocm_C
.
wvSplitK
(
a
,
b
,
bias
,
cu_count
)
def
wvSplitKrc
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
cu_count
:
int
,
bias
:
torch
.
Tensor
=
None
)
->
torch
.
Tensor
:
a
,
b
=
rocm_enforce_contiguous_skinny_gemm_inputs
(
a
,
b
)
return
torch
.
ops
.
_rocm_C
.
wvSplitKrc
(
a
,
b
,
bias
,
cu_count
)
...
...
@@ -2290,7 +2277,6 @@ def wvSplitKQ(
cu_count
:
int
,
bias
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
a
,
b
=
rocm_enforce_contiguous_skinny_gemm_inputs
(
a
,
b
)
out
=
torch
.
empty
((
b
.
shape
[
0
],
a
.
shape
[
0
]),
dtype
=
out_dtype
,
device
=
b
.
device
)
torch
.
ops
.
_rocm_C
.
wvSplitKQ
(
a
,
b
,
bias
,
out
,
scale_a
,
scale_b
,
cu_count
)
return
out
...
...
vllm/compilation/backends.py
View file @
45a060d6
...
...
@@ -361,7 +361,14 @@ def split_graph(
subgraph_id
+=
1
node_to_subgraph_id
[
node
]
=
subgraph_id
split_op_graphs
.
append
(
subgraph_id
)
subgraph_id
+=
1
# keep consecutive splitting ops together
# (we know node.next exists because node isn't the last (output) node)
if
should_split
(
node
.
next
,
splitting_ops
):
# this will get incremented by the next node
subgraph_id
-=
1
else
:
subgraph_id
+=
1
else
:
node_to_subgraph_id
[
node
]
=
subgraph_id
...
...
vllm/config/compilation.py
View file @
45a060d6
...
...
@@ -583,6 +583,24 @@ class CompilationConfig:
local_cache_dir
:
str
=
field
(
default
=
None
,
init
=
False
)
# type: ignore
"""local cache dir for each rank"""
fast_moe_cold_start
=
True
"""Optimization for fast MOE cold start.
This is a bit of a hack that assumes that:
1. the only decoder forward pass being run is the current model
2. the decoder forward pass runs all of the MOEs in the order in which they
are initialized
When the above two conditions hold, this option greatly decreases cold start
time for MOE models.
If the above two conditions don't hold, then this option will lead to silent
incorrectness. The only condition in which this doesn't hold is speculative
decoding, where there is a draft model that may have MOEs in them.
NB: We're working on a longer-term solution that doesn't need these assumptions.
"""
# keep track of enabled and disabled custom ops
enabled_custom_ops
:
Counter
[
str
]
=
field
(
default_factory
=
Counter
,
init
=
False
)
"""custom ops that are enabled"""
...
...
@@ -598,6 +616,10 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
static_all_moe_layers
:
list
[
str
]
=
field
(
default_factory
=
list
,
init
=
False
)
"""The names of all the MOE layers in the model
"""
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
_attention_ops
:
ClassVar
[
list
[
str
]]
=
[
...
...
@@ -927,6 +949,15 @@ class CompilationConfig:
# for details. Make a copy to avoid mutating the class-level
# list via reference.
self
.
splitting_ops
=
list
(
self
.
_attention_ops
)
# unified_kv_cache_update has a string param that prevents Inductor
# from reusing piecewise graphs. Remove it from the compiled graph.
# This has the side-effect of excluding cache from cudagraphs but
# that doesn't seem to affect performance.
# https://github.com/vllm-project/vllm/issues/33267
if
not
self
.
use_inductor_graph_partition
:
self
.
splitting_ops
.
append
(
"vllm::unified_kv_cache_update"
)
elif
len
(
self
.
splitting_ops
)
==
0
:
if
(
self
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
...
...
vllm/config/speculative.py
View file @
45a060d6
...
...
@@ -41,6 +41,7 @@ MTPModelTypes = Literal[
"longcat_flash_mtp"
,
"mtp"
,
"pangu_ultra_moe_mtp"
,
"step3p5_mtp"
,
]
EagleModelTypes
=
Literal
[
"eagle"
,
"eagle3"
,
MTPModelTypes
]
SpeculativeMethod
=
Literal
[
...
...
@@ -264,6 +265,11 @@ class SpeculativeConfig:
{
"n_predict"
:
n_predict
,
"architectures"
:
[
"LongCatFlashMTPModel"
]}
)
if
hf_config
.
model_type
==
"step3p5"
:
hf_config
.
model_type
=
"step3p5_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
1
)
hf_config
.
update
({
"n_predict"
:
n_predict
,
"architectures"
:
[
"Step3p5MTP"
]})
if
initial_architecture
==
"MistralLarge3ForCausalLM"
:
hf_config
.
update
({
"architectures"
:
[
"EagleMistralLarge3ForCausalLM"
]})
...
...
vllm/forward_context.py
View file @
45a060d6
...
...
@@ -219,9 +219,11 @@ class ForwardContext:
# the graph.
#
# The workaround is to store a list of the strings that each of those
# custom ops needs, in reverse order, in the ForwardContext.
# custom ops needs in the ForwardContext (all_moe_layers)
# as well as a counter (moe_layer_index).
# The ForwardContext object is alive for the duration of the forward pass.
# When the custom op needs the string, pop the string from this list.
# When the custom op needs a layer string, get the next string
# from all_moe_layers and increment the counter.
#
# This assumes that the custom operators will always be executed in
# order and that torch.compile will not try to reorder these
...
...
@@ -235,7 +237,8 @@ class ForwardContext:
#
# If this value is None (like in some tests), then we end up baking the string
# into the graph. Otherwise, the moe custom ops will read the next string from this list.
remaining_moe_layers
:
list
[
str
]
|
None
=
None
all_moe_layers
:
list
[
str
]
|
None
=
None
moe_layer_index
:
int
=
0
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
...
...
@@ -273,17 +276,22 @@ def create_forward_context(
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
):
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
remaining_moe_layers
=
[
name
for
name
,
layer
in
no_compile_layers
.
items
()
if
isinstance
(
layer
,
FusedMoE
)
]
remaining_moe_layers
.
reverse
()
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
speculative_config
is
None
:
all_moe_layers
=
vllm_config
.
compilation_config
.
static_all_moe_layers
else
:
logger
.
warning_once
(
"vllm_config.compilation_config.fast_moe_cold_start is not "
"compatible with speculative decoding so we are ignoring "
"fast_moe_cold_start."
)
all_moe_layers
=
None
else
:
all_moe_layers
=
None
return
ForwardContext
(
no_compile_layers
=
no_compile_layers
,
remaining
_moe_layers
=
remaining
_moe_layers
,
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
all
_moe_layers
=
all
_moe_layers
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
slot_mapping
=
slot_mapping
or
{},
...
...
vllm/model_executor/layers/activation.py
View file @
45a060d6
...
...
@@ -17,12 +17,64 @@ from vllm.logger import init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.collection_utils
import
LazyDict
import
vllm.envs
as
envs
logger
=
init_logger
(
__name__
)
@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """Clamped SwiGLU for one BLOCK_SIZE-wide slice of one row.

    For row ``i`` the kernel reads ``gate = x[i, :d]`` and ``up = x[i, d:2d]``
    and stores ``min(silu(gate), limit) * clamp(up, -limit, limit)`` into
    ``o[i, :d]``. Program axis 0 selects the row, program axis 1 selects the
    column block. Only the row stride is passed in, so elements within a row
    are addressed with unit stride.
    """
    # Widen the row index to int64 before it is multiplied by the row stride
    # (presumably to avoid 32-bit overflow on large tensors).
    row = tl.program_id(axis=0).to(tl.int64)
    col_block = tl.program_id(axis=1)

    x_row_ptr = x_ptr + x_stride * row
    o_row_ptr = o_ptr + o_stride * row

    # Column indices handled by this program; the tail block is masked.
    cols = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < d

    # Do the math in float32 regardless of the storage dtype.
    gate = tl.load(x_row_ptr + cols, mask=in_bounds).to(tl.float32)
    up = tl.load(x_row_ptr + cols + d, mask=in_bounds).to(tl.float32)

    # silu(gate), capped from above only.
    silu_gate = tl.sigmoid(gate) * gate
    silu_gate_capped = tl.minimum(silu_gate, limit)
    # up is clamped symmetrically to [-limit, limit].
    up_capped = tl.minimum(tl.maximum(up, -limit), limit)

    product = silu_gate_capped * up_capped
    # Cast back to the input element type before storing.
    product = product.to(x_ptr.dtype.element_ty)
    tl.store(o_row_ptr + cols, product, mask=in_bounds)
def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
) -> None:
    """Launch the clamped-SwiGLU Triton kernel on a 2-D input.

    Args:
        output: Destination tensor. The kernel writes ``d`` values per row,
            addressed via ``output.stride(0)``.
        input: 2-D tensor of shape ``(b, 2 * d)``. The first ``d`` columns are
            the gate and the last ``d`` columns are the up projection.
        limit: Clamp bound. The gate (after SiLU) is capped at ``limit`` and
            the up projection is clamped to ``[-limit, limit]``.

    Raises:
        AssertionError: If ``input`` is not 2-D or its last dimension is odd.
    """
    # Check the rank *before* unpacking the shape; otherwise a non-2-D input
    # fails with a tuple-unpacking ValueError and this assertion never fires.
    assert input.ndim == 2, "swiglustep_and_mul_triton expects a 2-D input"
    b, n = input.shape
    assert n % 2 == 0, "last dimension must hold gate and up halves"
    d = n // 2

    def grid(meta):
        # One program per (row, BLOCK_SIZE-wide column block).
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        output,
        output.stride(0),
        input,
        input.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )
# --8<-- [start:fatrelu_and_mul]
@
CustomOp
.
register
(
"fatrelu_and_mul"
)
class
FatreluAndMul
(
CustomOp
):
...
...
@@ -317,6 +369,44 @@ class SwigluOAIAndMul(CustomOp):
return
f
"alpha=
{
repr
(
self
.
alpha
)
}
, limit=
{
repr
(
self
.
limit
)
}
"
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """An activation function for SwiGLU with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, limit: float = 7.0):
        """Store the clamp bound.

        Raises:
            ValueError: If ``limit`` is None.
        """
        super().__init__()
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = x.chunk(2, dim=-1)
        gate = F.silu(gate)
        # The gate is only capped from above; up is clamped on both sides.
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Triton implementation; accepts any number of leading dimensions."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # The Triton wrapper only accepts 2-D tensors, so collapse all leading
        # dimensions into one. `out` is freshly allocated (contiguous), so the
        # view aliases it and the kernel writes land directly in `out`.
        swiglustep_and_mul_triton(
            out.view(-1, d),
            x.reshape(-1, x.shape[-1]),
            self.limit,
        )
        return out

    def extra_repr(self) -> str:
        """Return the clamp bound for the module repr."""
        return f"limit={repr(self.limit)}"
# --8<-- [start:gelu_new]
@
CustomOp
.
register
(
"gelu_new"
)
class
NewGELU
(
CustomOp
):
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
45a060d6
...
...
@@ -657,7 +657,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
@
staticmethod
def
_supports_current_device
()
->
bool
:
return
current_platform
.
has_device_capability
((
10
,
0
))
p
=
current_platform
return
p
.
is_cuda
()
and
(
p
.
is_device_capability_family
(
100
)
or
p
.
is_device_capability_family
(
110
)
or
p
.
is_device_capability_family
(
120
)
)
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
45a060d6
...
...
@@ -144,7 +144,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
]
return
activation
in
[
"silu"
,
"swiglustep"
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
View file @
45a060d6
...
...
@@ -54,7 +54,8 @@ class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
@
staticmethod
def
_supports_current_device
()
->
bool
:
return
current_platform
.
is_device_capability_family
(
100
)
p
=
current_platform
return
p
.
is_cuda
()
and
p
.
is_device_capability_family
(
100
)
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
...
...
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
View file @
45a060d6
...
...
@@ -84,11 +84,14 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
@
staticmethod
def
_supports_current_device
()
->
bool
:
p
=
current_platform
return
(
current_platform
.
is_cuda
()
p
.
is_cuda
()
and
(
current_platform
.
is_device_capability
((
9
,
0
))
or
current_platform
.
is_device_capability_family
(
100
)
p
.
is_device_capability
(
90
)
or
p
.
is_device_capability_family
(
100
)
or
p
.
is_device_capability_family
(
110
)
or
p
.
is_device_capability_family
(
120
)
)
and
has_flashinfer_cutlass_fused_moe
()
)
...
...
@@ -102,29 +105,27 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
# The following are supported by FlashInferExperts:
# * unquantized
# * fp8 static per-tensor on 9.0+
# * fp8 block on 9.0
# * nvfp4 on 10.0+
p
=
current_platform
scheme
=
(
weight_key
,
activation_key
)
# The following are supported by FlashInferExperts:
return
(
# unquantized and fp8 static per-tensor on 9.0+
(
scheme
in
[
(
None
,
None
),
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
),
]
and
p
.
has_device_capability
(
90
)
)
# fp8 block-scale on 9.0
or
(
(
scheme
==
(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
)
)
and
(
p
.
is_device_capability
(
(
9
,
0
))
)
scheme
==
(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
)
and
p
.
is_device_capability
(
90
)
)
# nvfp4 on 10.0+
or
(
(
scheme
==
(
kNvfp4Static
,
kNvfp4Dynamic
))
and
(
p
.
is_device_capability_family
(
100
))
scheme
==
(
kNvfp4Static
,
kNvfp4Dynamic
)
and
p
.
has_device_capability
(
100
)
)
)
...
...
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
View file @
45a060d6
...
...
@@ -30,7 +30,6 @@ from vllm.utils.torch_utils import direct_register_custom_op
def _supports_current_device() -> bool:
    """Supports only Blackwell-family GPUs.

    Returns True when the current platform is CUDA and reports device
    capability family 100.
    """
    # TODO: also check that flashinfer trtllm is available.
    if not current_platform.is_cuda():
        return False
    return current_platform.is_device_capability_family(100)
...
...
@@ -70,9 +69,14 @@ def _supports_routing_method(
RoutingMethodType
.
RenormalizeNaive
,
]
elif
(
weight_key
,
activation_key
)
==
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
):
# NOTE(rob): kernel requires Llama4.
return
routing_method
==
RoutingMethodType
.
Llama4
# NOTE(dbari): as above, potentially allow others here.
return
routing_method
in
[
RoutingMethodType
.
Llama4
,
# NOTE(mgoin): Disabled to investigate accuracy issues.
# See https://github.com/vllm-project/vllm/issues/33532
# RoutingMethodType.Renormalize,
# RoutingMethodType.RenormalizeNaive,
]
else
:
raise
ValueError
(
"Unsupported quantization scheme."
)
...
...
@@ -82,7 +86,23 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
return
not
moe_parallel_config
.
enable_eplb
def
is_supported_config_trtllm
(
def
_supports_router_logits_dtype
(
router_logits_dtype
:
torch
.
dtype
|
None
,
routing_method
:
RoutingMethodType
,
)
->
bool
:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
"""
if
router_logits_dtype
==
torch
.
float32
:
# Only DeepSeekV3 routing handles float32 logits
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return
routing_method
==
RoutingMethodType
.
DeepSeekV3
return
True
def
is_supported_config_trtllm_fp8
(
moe_config
:
FusedMoEConfig
,
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
...
...
@@ -111,13 +131,17 @@ def is_supported_config_trtllm(
return
False
,
_make_reason
(
"routing method"
)
elif
activation_format
!=
mk
.
FusedMoEActivationFormat
.
Standard
:
return
False
,
_make_reason
(
"activation format"
)
elif
not
_supports_router_logits_dtype
(
moe_config
.
router_logits_dtype
,
moe_config
.
routing_method
):
return
False
,
_make_reason
(
"float32 router_logits with non-DeepSeekV3 routing"
)
return
True
,
None
def
flashinfer_fused_moe_blockscale_fp8
(
routing_logits
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
|
None
,
x
:
torch
.
Tensor
,
w13_weight
:
torch
.
Tensor
,
w13_weight_scale_inv
:
torch
.
Tensor
,
...
...
@@ -131,7 +155,7 @@ def flashinfer_fused_moe_blockscale_fp8(
expert_offset
:
int
,
local_num_experts
:
int
,
block_shape
:
list
[
int
],
routing_method_type
:
int
=
int
(
RoutingMethodType
.
DeepSeekV3
)
,
routing_method_type
:
int
,
routed_scaling
:
float
|
None
=
1.0
,
)
->
torch
.
Tensor
:
from
vllm.utils.flashinfer
import
flashinfer_trtllm_fp8_block_scale_moe
...
...
@@ -144,6 +168,13 @@ def flashinfer_fused_moe_blockscale_fp8(
# Routing kernel expects #experts <= #threads 512
assert
global_num_experts
<=
512
# The DeepSeekV3 routing method requires float32 router logits.
if
routing_method_type
==
RoutingMethodType
.
DeepSeekV3
:
routing_logits
=
routing_logits
.
to
(
torch
.
float32
)
if
routing_bias
is
not
None
:
routing_bias
=
routing_bias
.
to
(
x
.
dtype
)
a_q
,
a_sf
=
per_token_group_quant_fp8
(
x
,
block_shape
[
1
])
# NOTE: scales of hidden states have to be transposed!
a_sf_t
=
a_sf
.
t
().
contiguous
()
...
...
@@ -171,7 +202,7 @@ def flashinfer_fused_moe_blockscale_fp8(
def
flashinfer_fused_moe_blockscale_fp8_fake
(
routing_logits
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
,
routing_bias
:
torch
.
Tensor
|
None
,
x
:
torch
.
Tensor
,
w13_weight
:
torch
.
Tensor
,
w13_weight_scale_inv
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
45a060d6
...
...
@@ -933,6 +933,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
SUPPORTED_W_A_FP8
=
[
(
kFp8Static128BlockSym
,
kFp8Dynamic128Sym
),
(
kFp8StaticChannelSym
,
kFp8DynamicTokenSym
),
(
kFp8StaticTensorSym
,
kFp8DynamicTokenSym
),
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
),
(
kFp8StaticTensorSym
,
kFp8DynamicTensorSym
),
]
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
45a060d6
...
...
@@ -57,6 +57,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8Dynamic128Sym
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8Static128BlockSym
,
kFp8StaticChannelSym
,
...
...
@@ -2312,12 +2313,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
(
kFp8StaticChannelSym
,
kFp8DynamicTokenSym
),
(
kFp8StaticTensorSym
,
kFp8DynamicTokenSym
),
(
kFp8StaticTensorSym
,
kFp8StaticTensorSym
),
(
kFp8StaticTensorSym
,
kFp8DynamicTensorSym
),
]
return
(
weight_key
,
activation_key
)
in
SUPPORTED_W_A
@
staticmethod
def
_supports_activation
(
activation
:
str
)
->
bool
:
return
activation
in
[
"silu"
,
"gelu"
,
"swigluoai"
]
return
activation
in
[
"silu"
,
"gelu"
,
"swigluoai"
,
"swiglustep"
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
45a060d6
...
...
@@ -412,6 +412,7 @@ class FusedMoE(CustomOp):
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
prefix
))
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_all_moe_layers
.
append
(
prefix
)
self
.
layer_name
=
prefix
self
.
enable_eplb
=
enable_eplb
...
...
@@ -1606,7 +1607,7 @@ class FusedMoE(CustomOp):
# Can be unavailable or None in unittests
if
(
is_forward_context_available
()
and
get_forward_context
().
remaining
_moe_layers
is
not
None
and
get_forward_context
().
all
_moe_layers
is
not
None
):
return
"from_forward_context"
return
self
.
layer_name
...
...
@@ -2060,13 +2061,17 @@ class FusedMoE(CustomOp):
def
get_layer_from_name
(
layer_name
:
str
)
->
FusedMoE
:
forward_context
:
ForwardContext
=
get_forward_context
()
if
layer_name
==
"from_forward_context"
:
if
not
forward_context
.
remaining_moe_layers
:
all_moe_layers
=
forward_context
.
all_moe_layers
assert
all_moe_layers
is
not
None
moe_layer_index
=
forward_context
.
moe_layer_index
if
moe_layer_index
>=
len
(
all_moe_layers
):
raise
AssertionError
(
"We expected the number of MOE layers in `
remaining
_moe_layers` "
"We expected the number of MOE layers in `
all
_moe_layers` "
"to be equal to the number of "
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
)
layer_name
=
forward_context
.
remaining_moe_layers
.
pop
()
layer_name
=
all_moe_layers
[
moe_layer_index
]
forward_context
.
moe_layer_index
+=
1
self
=
cast
(
FusedMoE
,
forward_context
.
no_compile_layers
[
layer_name
])
return
self
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
45a060d6
...
...
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a16_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
import
(
is_supported_config_trtllm
,
is_supported_config_trtllm
_fp8
,
)
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
...
...
@@ -213,7 +213,7 @@ def select_fp8_moe_backend(
if
fi_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
backend
=
Fp8MoeBackend
.
FLASHINFER_TRTLLM
supported
,
reason
=
is_supported_config_trtllm
(
supported
,
reason
=
is_supported_config_trtllm
_fp8
(
config
,
weight_key
,
activation_key
,
activation_format
)
if
supported
:
...
...
@@ -240,7 +240,7 @@ def select_fp8_moe_backend(
]:
if
backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
k_cls
=
None
supported
,
reason
=
is_supported_config_trtllm
(
supported
,
reason
=
is_supported_config_trtllm
_fp8
(
config
,
weight_key
,
activation_key
,
...
...
@@ -309,7 +309,7 @@ def select_fp8_moe_backend(
for
backend
in
AVAILABLE_BACKENDS
:
if
backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
k_cls
=
None
supported
,
reason
=
is_supported_config_trtllm
(
supported
,
reason
=
is_supported_config_trtllm
_fp8
(
config
,
weight_key
,
activation_key
,
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment