Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
656
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
835 additions
and
237 deletions
+835
-237
tests/v1/logits_processors/test_custom_offline.py
tests/v1/logits_processors/test_custom_offline.py
+0
-5
tests/v1/logits_processors/test_custom_online.py
tests/v1/logits_processors/test_custom_online.py
+2
-10
tests/v1/logits_processors/utils.py
tests/v1/logits_processors/utils.py
+1
-1
tests/v1/metrics/test_stats.py
tests/v1/metrics/test_stats.py
+102
-1
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+6
-2
tests/v1/spec_decode/test_max_len.py
tests/v1/spec_decode/test_max_len.py
+2
-2
tests/v1/spec_decode/test_speculators_eagle3.py
tests/v1/spec_decode/test_speculators_eagle3.py
+5
-0
tests/v1/spec_decode/test_tree_attention.py
tests/v1/spec_decode/test_tree_attention.py
+2
-2
tests/v1/structured_output/test_backend_guidance.py
tests/v1/structured_output/test_backend_guidance.py
+74
-0
tests/v1/structured_output/test_reasoning_structured_output.py
.../v1/structured_output/test_reasoning_structured_output.py
+1
-0
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+15
-6
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+90
-86
tests/v1/worker/test_gpu_profiler.py
tests/v1/worker/test_gpu_profiler.py
+36
-35
tools/ep_kernels/install_python_libraries.sh
tools/ep_kernels/install_python_libraries.sh
+6
-12
use_existing_torch.py
use_existing_torch.py
+2
-5
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+290
-48
vllm/_custom_ops.py
vllm/_custom_ops.py
+134
-2
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+19
-16
vllm/attention/layer.py
vllm/attention/layer.py
+47
-3
vllm/attention/layers/cross_attention.py
vllm/attention/layers/cross_attention.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
656 of 656+
files are displayed.
Plain diff
Email patch
tests/v1/logits_processors/test_custom_offline.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
sys
from
typing
import
Any
import
pytest
...
...
@@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
from
tests.v1.logits_processors.utils
import
(
DUMMY_LOGITPROC_ARG
,
DUMMY_LOGITPROC_FQCN
,
DUMMY_LOGITPROC_MODULE
,
MAX_TOKENS
,
MODEL_NAME
,
POOLING_MODEL_NAME
,
...
...
@@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
CustomLogitprocSource
,
DummyLogitsProcessor
,
WrappedPerReqLogitsProcessor
,
dummy_module
,
prompts
,
)
from
tests.v1.logits_processors.utils
import
entry_points
as
fake_entry_points
...
...
@@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
kwargs
:
dict
[
str
,
list
[
str
|
type
[
LogitsProcessor
]]]
=
{}
if
logitproc_source
==
CustomLogitprocSource
.
LOGITPROC_SOURCE_FQCN
:
# Scenario: load logitproc based on fully-qualified class name (FQCN)
# Inject dummy module which defines logitproc
sys
.
modules
[
DUMMY_LOGITPROC_MODULE
]
=
dummy_module
kwargs
[
"logits_processors"
]
=
[
DUMMY_LOGITPROC_FQCN
]
elif
logitproc_source
==
CustomLogitprocSource
.
LOGITPROC_SOURCE_CLASS
:
# Scenario: load logitproc from provided class object
...
...
tests/v1/logits_processors/test_custom_online.py
View file @
8d75f22e
...
...
@@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
from
tests.v1.logits_processors.utils
import
(
DUMMY_LOGITPROC_ARG
,
DUMMY_LOGITPROC_FQCN
,
DUMMY_LOGITPROC_MODULE
,
MAX_TOKENS
,
MODEL_NAME
,
TEMP_GREEDY
,
dummy_module
,
prompts
,
)
from
tests.v1.logits_processors.utils
import
entry_points
as
fake_entry_points
...
...
@@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
main
.
main
()
def
_server_with_logitproc_
module
(
def
_server_with_logitproc_
fqcn
(
env_dict
:
dict
[
str
,
str
]
|
None
,
model
:
str
,
vllm_serve_args
:
list
[
str
],
)
->
None
:
"""Start vLLM server, inject module with dummy logitproc"""
# Patch `modules` to inject dummy logitproc module
from
vllm.entrypoints.cli
import
main
sys
.
modules
[
DUMMY_LOGITPROC_MODULE
]
=
dummy_module
# fork is required for workers to see entrypoint patch
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"fork"
if
env_dict
is
not
None
:
os
.
environ
.
update
(
env_dict
)
...
...
@@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
if
request
.
param
:
# Launch server, append FQCN argument, inject dummy logitproc module
args
=
default_server_args
+
request
.
param
_server_fxn
=
_server_with_logitproc_
module
_server_fxn
=
_server_with_logitproc_
fqcn
else
:
# Launch server, inject dummy logitproc entrypoint
args
=
default_server_args
...
...
tests/v1/logits_processors/utils.py
View file @
8d75f22e
...
...
@@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY
=
0.0
MAX_TOKENS
=
20
DUMMY_LOGITPROC_ENTRYPOINT
=
"dummy_logitproc"
DUMMY_LOGITPROC_MODULE
=
"
DummyModule
"
DUMMY_LOGITPROC_MODULE
=
"
tests.v1.logits_processors.utils
"
DUMMY_LOGITPROC_FQCN
=
f
"
{
DUMMY_LOGITPROC_MODULE
}
:DummyLogitsProcessor"
...
...
tests/v1/metrics/test_stats.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.stats
import
IterationStats
,
RequestStateStats
def
test_iteration_stats_repr
():
iteration_stats
=
IterationStats
()
assert
repr
(
iteration_stats
).
startswith
(
"IterationStats("
)
def
test_prefill_kv_computed_with_cache
():
"""Test that prefill KV compute correctly excludes cached tokens."""
iteration_stats
=
IterationStats
()
req_stats
=
RequestStateStats
(
arrival_time
=
0.0
)
req_stats
.
scheduled_ts
=
0.1
req_stats
.
first_token_ts
=
0.5
req_stats
.
last_token_ts
=
5.0
req_stats
.
num_generation_tokens
=
50
# Case 1: With prefix cache (1200 tokens cached)
iteration_stats
.
update_from_finished_request
(
finish_reason
=
FinishReason
.
STOP
,
num_prompt_tokens
=
10000
,
max_tokens_param
=
100
,
req_stats
=
req_stats
,
num_cached_tokens
=
1200
,
)
finished_req
=
iteration_stats
.
finished_requests
[
0
]
assert
finished_req
.
num_prompt_tokens
==
10000
assert
finished_req
.
num_cached_tokens
==
1200
# Verify calculation: prefill KV = prompt tokens - cached tokens
prefill_kv_computed
=
finished_req
.
num_prompt_tokens
-
max
(
finished_req
.
num_cached_tokens
,
0
)
assert
prefill_kv_computed
==
8800
# 10000 - 1200
def
test_prefill_kv_computed_no_cache
():
"""Test prefill KV compute without prefix caching."""
iteration_stats
=
IterationStats
()
req_stats
=
RequestStateStats
(
arrival_time
=
0.0
)
req_stats
.
scheduled_ts
=
0.1
req_stats
.
first_token_ts
=
0.5
req_stats
.
last_token_ts
=
2.0
req_stats
.
num_generation_tokens
=
10
# Case 2: No prefix cache
iteration_stats
.
update_from_finished_request
(
finish_reason
=
FinishReason
.
STOP
,
num_prompt_tokens
=
2000
,
max_tokens_param
=
100
,
req_stats
=
req_stats
,
num_cached_tokens
=
0
,
)
finished_req
=
iteration_stats
.
finished_requests
[
0
]
assert
finished_req
.
num_prompt_tokens
==
2000
assert
finished_req
.
num_cached_tokens
==
0
# Verify calculation: prefill KV = full prompt when no cache
prefill_kv_computed
=
finished_req
.
num_prompt_tokens
-
max
(
finished_req
.
num_cached_tokens
,
0
)
assert
prefill_kv_computed
==
2000
def
test_prefill_kv_computed_edge_cases
():
"""Test edge cases for prefill KV compute calculation."""
iteration_stats
=
IterationStats
()
req_stats
=
RequestStateStats
(
arrival_time
=
0.0
)
req_stats
.
scheduled_ts
=
0.1
req_stats
.
first_token_ts
=
0.5
req_stats
.
last_token_ts
=
1.0
req_stats
.
num_generation_tokens
=
1
# Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully)
iteration_stats
.
update_from_finished_request
(
finish_reason
=
FinishReason
.
STOP
,
num_prompt_tokens
=
100
,
max_tokens_param
=
10
,
req_stats
=
req_stats
,
num_cached_tokens
=-
1
,
)
finished_req
=
iteration_stats
.
finished_requests
[
0
]
# max() should handle negative values
prefill_kv_computed
=
finished_req
.
num_prompt_tokens
-
max
(
finished_req
.
num_cached_tokens
,
0
)
assert
prefill_kv_computed
==
100
# Should treat negative as 0
# Case 4: All tokens cached (shouldn't happen in practice)
iteration_stats2
=
IterationStats
()
iteration_stats2
.
update_from_finished_request
(
finish_reason
=
FinishReason
.
STOP
,
num_prompt_tokens
=
100
,
max_tokens_param
=
10
,
req_stats
=
req_stats
,
num_cached_tokens
=
100
,
)
finished_req2
=
iteration_stats2
.
finished_requests
[
0
]
prefill_kv_computed2
=
finished_req2
.
num_prompt_tokens
-
max
(
finished_req2
.
num_cached_tokens
,
0
)
assert
prefill_kv_computed2
==
0
# All cached, nothing computed
tests/v1/spec_decode/test_eagle.py
View file @
8d75f22e
...
...
@@ -339,7 +339,7 @@ def test_load_model(
"multi-token eagle spec decode on current platform"
)
if
attn_backend
==
"
FLASH_ATTN
"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"
ROCM_AITER_FA
"
and
current_platform
.
is_rocm
():
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
# Setup draft model mock
...
...
@@ -436,7 +436,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
"because it requires special input mocking."
)
if
attn_backend
==
"
FLASH_ATTN
"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"
ROCM_AITER_FA
"
and
current_platform
.
is_rocm
():
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
# Use GPU device
...
...
@@ -543,6 +543,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
attn_metadata_builder_cls
,
_
=
try_get_attention_backend
(
AttentionBackendEnum
.
TREE_ATTN
)
elif
attn_backend
==
"ROCM_AITER_FA"
:
attn_metadata_builder_cls
,
_
=
try_get_attention_backend
(
AttentionBackendEnum
.
ROCM_AITER_FA
)
else
:
raise
ValueError
(
f
"Unsupported attention backend:
{
attn_backend
}
"
)
...
...
tests/v1/spec_decode/test_max_len.py
View file @
8d75f22e
...
...
@@ -47,7 +47,7 @@ def test_eagle_max_len(
"multi-token eagle spec decode on current platform"
)
if
attn_backend
==
"
FLASH_ATTN
"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"
ROCM_AITER_FA
"
and
current_platform
.
is_rocm
():
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
llm
=
LLM
(
...
...
@@ -82,7 +82,7 @@ def test_eagle_max_len(
len
(
o
.
prompt_token_ids
)
<
80
<
len
(
o
.
prompt_token_ids
)
+
len
(
o
.
outputs
[
0
].
token_ids
)
<
200
<
=
200
),
(
"This test is only meaningful if the output "
"is longer than the eagle max length"
...
...
tests/v1/spec_decode/test_speculators_eagle3.py
View file @
8d75f22e
...
...
@@ -5,6 +5,7 @@ import torch
from
vllm.config
import
SpeculativeConfig
from
vllm.model_executor.models.interfaces
import
supports_eagle3
from
vllm.platforms
import
current_platform
@
pytest
.
mark
.
parametrize
(
...
...
@@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
pytest
.
param
(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16"
,
id
=
"qwen3-eagle3-speculator-w4a16-verifier"
,
marks
=
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"The tests are skipped on rocm platform."
,
),
),
],
)
...
...
tests/v1/spec_decode/test_tree_attention.py
View file @
8d75f22e
...
...
@@ -88,8 +88,8 @@ def forward_attention(
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc
.
cpu
(),
seq_lens
=
seq_lens
,
seq_lens_cpu
=
seq_lens
.
cpu
(),
num_computed_tokens_cpu
=
context_lens
.
cpu
(),
_
seq_lens_cpu
=
seq_lens
.
cpu
(),
_
num_computed_tokens_cpu
=
context_lens
.
cpu
(),
num_reqs
=
batch_size
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
...
...
tests/v1/structured_output/test_backend_guidance.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
concurrent.futures
import
Future
import
pytest
from
transformers
import
AutoTokenizer
from
vllm.config
import
StructuredOutputsConfig
,
VllmConfig
from
vllm.config.model
import
ModelConfig
from
vllm.config.parallel
import
ParallelConfig
from
vllm.config.speculative
import
SpeculativeConfig
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.v1.request
import
Request
...
...
@@ -116,3 +121,72 @@ def test_grammar_bitmask_with_specdec():
)
# EOS not the final token
grammar_bitmask
(
request
,
prompt
[
i
:])
# EOS not present
grammar_bitmask
(
request
,
prompt
[
i
:]
+
[
tokenizer
.
eos_token_id
])
@
pytest
.
mark
.
parametrize
(
"async_grammar"
,
[
True
,
False
])
def
test_grammar_init_async_and_sync
(
async_grammar
):
"""Test grammar initialization works correctly in both async and sync modes.
This test validates that the distributed_executor_backend config option
correctly controls whether grammar compilation happens asynchronously
(via executor.submit) or synchronously. When set to "external_launcher",
grammar compilation is synchronous to avoid deadlocks.
"""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TOKENIZER
)
prompt
=
tokenizer
.
encode
(
'{"a": "b"}'
)
# Use "external_launcher" for sync mode, None for async mode
executor_backend
=
None
if
async_grammar
else
"external_launcher"
vllm_config
=
VllmConfig
(
model_config
=
ModelConfig
(
tokenizer
=
TOKENIZER
),
structured_outputs_config
=
StructuredOutputsConfig
(
backend
=
"guidance"
),
parallel_config
=
ParallelConfig
(
distributed_executor_backend
=
executor_backend
),
)
structured_output_manager
=
StructuredOutputManager
(
vllm_config
)
sampling_params
=
SamplingParams
(
structured_outputs
=
StructuredOutputsParams
(
json
=
'{"type": "object"}'
,
),
)
sampling_params
.
structured_outputs
.
_backend
=
"guidance"
request
=
Request
(
"test_request"
,
prompt_token_ids
=
prompt
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
eos_token_id
=
tokenizer
.
eos_token_id
,
)
structured_output_manager
.
grammar_init
(
request
)
# Check the internal _grammar type immediately after init
# Before _check_grammar_completion is called, async mode should have a Future
raw_grammar
=
request
.
structured_output_request
.
_grammar
if
async_grammar
:
assert
isinstance
(
raw_grammar
,
Future
),
(
"Async mode should store a Future before completion"
)
else
:
assert
not
isinstance
(
raw_grammar
,
Future
),
(
"Sync mode should store the grammar directly, not a Future"
)
# Wait for grammar to be ready (handles both async and sync cases)
start_time
=
time
.
time
()
while
not
request
.
structured_output_request
.
_check_grammar_completion
():
if
time
.
time
()
-
start_time
>
5
:
# 5-second timeout
pytest
.
fail
(
"Grammar compilation timed out"
)
time
.
sleep
(
0.01
)
# After completion, _grammar should no longer be a Future
assert
not
isinstance
(
request
.
structured_output_request
.
_grammar
,
Future
)
# Verify grammar is properly initialized and functional
grammar
=
request
.
structured_output_request
.
grammar
assert
grammar
is
not
None
assert
not
grammar
.
is_terminated
()
# Verify the grammar can accept valid tokens
assert
grammar
.
accept_tokens
(
request
.
request_id
,
prompt
)
tests/v1/structured_output/test_reasoning_structured_output.py
View file @
8d75f22e
...
...
@@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
request
.
use_structured_output
=
True
request
.
prompt_token_ids
=
[
1
,
2
,
3
,
4
,
5
]
request
.
all_token_ids
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
]
request
.
num_computed_tokens
=
5
return
request
def
test_should_fill_bitmask_with_enable_in_reasoning
(
...
...
tests/v1/test_serial_utils.py
View file @
8d75f22e
...
...
@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
def
test_multimodal_kwargs
():
e1
=
MultiModalFieldElem
(
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
()
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
(),
)
e2
=
MultiModalFieldElem
(
"video"
,
"v0"
,
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
MultiModalFlatField
([[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
[
slice
(
None
,
2
)]],
0
),
MultiModalFlatField
(
slices
=
[[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
[
slice
(
None
,
2
)]],
dim
=
0
,
),
)
e3
=
MultiModalFieldElem
(
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
4
)
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
batch_size
=
4
),
)
e4
=
MultiModalFieldElem
(
"image"
,
"i1"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalFlatField
([
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
2
),
MultiModalFlatField
(
slices
=
[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
dim
=
2
),
)
audio
=
MultiModalKwargsItem
.
from_elems
([
e1
])
video
=
MultiModalKwargsItem
.
from_elems
([
e2
])
...
...
@@ -138,8 +147,8 @@ def test_multimodal_kwargs():
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 143
06
, +-20 for minor changes
assert
14
2
75
<=
total_len
<=
14
3
25
# expected total encoding length, should be 143
95
, +-20 for minor changes
assert
14
3
75
<=
total_len
<=
14
4
25
decoded
=
decoder
.
decode
(
encoded
).
mm
[
0
]
assert
isinstance
(
decoded
,
MultiModalKwargsItems
)
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
8d75f22e
...
...
@@ -6,8 +6,10 @@ import pytest
import
torch
from
vllm.attention.backends.abstract
import
MultipleOf
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.layer
import
Attention
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
ModelConfig
,
ParallelConfig
,
...
...
@@ -761,7 +763,11 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert
kv_cache_config_after_init
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
def
test_hybrid_attention_mamba_tensor_shapes
(
monkeypatch
):
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Attention backend FLASHINFER is not supported on ROCm."
,
)
def
test_hybrid_attention_mamba_tensor_shapes
():
"""
The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers
...
...
@@ -802,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
cache_dtype
=
"auto"
,
)
parallel_config
=
ParallelConfig
()
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
FLASHINFER
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
parallel_config
=
parallel_config
,
attention_config
=
attention_config
,
)
layer_0
=
"model.layers.0.self_attn.attn"
...
...
@@ -816,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
layer_4
=
"model.layers.4.mixer"
layer_5
=
"model.layers.5.mixer"
with
set_current_vllm_config
(
vllm_config
),
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASHINFER"
)
with
set_current_vllm_config
(
vllm_config
):
hf_config
=
vllm_config
.
model_config
.
hf_config
fwd_context
=
{}
for
key
in
[
layer_0
,
layer_1
]:
...
...
@@ -847,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)
# suppress var not used error
assert
fwd_context
is
not
None
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASHINFER"
)
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
kv_cache_spec
=
runner
.
get_kv_cache_spec
()
...
...
@@ -861,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)[
0
]
runner
.
initialize_kv_cache
(
kv_cache_config
)
# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
num_blocks
=
kv_cache_config
.
num_blocks
ind
=
np
.
arange
(
num_blocks
)
np
.
random
.
shuffle
(
ind
)
blocks0
,
blocks1
=
ind
[:
(
num_blocks
//
2
)],
ind
[(
num_blocks
//
2
)
:]
# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
num_blocks
=
kv_cache_config
.
num_blocks
ind
=
np
.
arange
(
num_blocks
)
np
.
random
.
shuffle
(
ind
)
blocks0
,
blocks1
=
ind
[:
(
num_blocks
//
2
)],
ind
[(
num_blocks
//
2
)
:]
attn_shape
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
].
shape
conv_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
0
].
shape
ssm_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
1
].
shape
attn_shape
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
].
shape
conv_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
0
].
shape
ssm_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
1
].
shape
# assert we are using FlashInfer
assert
attn_shape
[
0
]
%
num_blocks
==
0
block_split_ratio
=
attn_shape
[
0
]
//
num_blocks
# assert we are using FlashInfer
assert
attn_shape
[
0
]
%
num_blocks
==
0
block_split_ratio
=
attn_shape
[
0
]
//
num_blocks
# use small blocks for testing to avoid memory issues
test_block_size
=
min
(
2
,
len
(
blocks0
),
len
(
blocks1
))
# use small blocks for testing to avoid memory issues
test_block_size
=
min
(
2
,
len
(
blocks0
),
len
(
blocks1
))
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point
=
num_blocks
//
2
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
mid_point
=
num_blocks
//
2
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention
=
np
.
array
([
0
,
1
])[:
test_block_size
]
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention
=
np
.
array
([
0
,
1
])[:
test_block_size
]
# mamba uses kernel blocks from second half
kv_blocks_for_mamba
=
np
.
array
([
mid_point
,
mid_point
+
1
])[:
test_block_size
]
# mamba uses kernel blocks from second half
kv_blocks_for_mamba
=
np
.
array
([
mid_point
,
mid_point
+
1
])[:
test_block_size
]
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape
=
attn_shape
[
2
:]
conv_constant_shape
=
conv_shape
[
1
:]
ssm_constant_shape
=
ssm_shape
[
1
:]
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape
=
attn_shape
[
2
:]
conv_constant_shape
=
conv_shape
[
1
:]
ssm_constant_shape
=
ssm_shape
[
1
:]
attn_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
attn_constant_shape
),
device
=
DEVICE
,
fill_value
=
3.33
)
conv_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
conv_constant_shape
),
device
=
DEVICE
,
fill_value
=
6.66
)
ssm_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
ssm_constant_shape
),
device
=
DEVICE
,
fill_value
=
9.99
)
attn_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
attn_constant_shape
),
device
=
DEVICE
,
fill_value
=
3.33
)
conv_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
conv_constant_shape
),
device
=
DEVICE
,
fill_value
=
6.66
)
ssm_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
ssm_constant_shape
),
device
=
DEVICE
,
fill_value
=
9.99
)
# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention
=
kv_blocks_for_attention
*
block_split_ratio
for
layer
in
[
layer_0
,
layer_1
]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for
i
,
kernel_block
in
enumerate
(
kernel_blocks_for_attention
):
vllm_ctx
[
layer
].
kv_cache
[
0
][
kernel_block
,
:]
=
attn_blocks_constant
[
i
]
# fill mamba blocks with constants using kernel block indices
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
=
conv_blocks_constant
[
i
]
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
=
ssm_blocks_constant
[
i
]
# verify attention and mamba contents are correct
for
layer
in
[
layer_0
,
layer_1
]:
for
i
,
kernel_block
in
enumerate
(
kernel_blocks_for_attention
):
actual_kv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
kernel_block
,
:]
expected
=
attn_blocks_constant
[
i
]
# Check K and V separately
assert
torch
.
equal
(
actual_kv
[
0
],
expected
)
assert
torch
.
equal
(
actual_kv
[
1
],
expected
)
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
actual_conv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
actual_ssm
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
expected_conv
=
conv_blocks_constant
[
i
]
expected_ssm
=
ssm_blocks_constant
[
i
]
assert
torch
.
equal
(
actual_conv
,
expected_conv
)
assert
torch
.
equal
(
actual_ssm
,
expected_ssm
)
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
actual_conv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
actual_ssm
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
expected_conv
=
conv_blocks_constant
[
i
]
expected_ssm
=
ssm_blocks_constant
[
i
]
assert
torch
.
equal
(
actual_conv
,
expected_conv
)
assert
torch
.
equal
(
actual_ssm
,
expected_ssm
)
# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention
=
kv_blocks_for_attention
*
block_split_ratio
for
layer
in
[
layer_0
,
layer_1
]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for
i
,
kernel_block
in
enumerate
(
kernel_blocks_for_attention
):
vllm_ctx
[
layer
].
kv_cache
[
0
][
kernel_block
,
:]
=
attn_blocks_constant
[
i
]
# fill mamba blocks with constants using kernel block indices
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
=
conv_blocks_constant
[
i
]
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
=
ssm_blocks_constant
[
i
]
# verify attention and mamba contents are correct
for
layer
in
[
layer_0
,
layer_1
]:
for
i
,
kernel_block
in
enumerate
(
kernel_blocks_for_attention
):
actual_kv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
kernel_block
,
:]
expected
=
attn_blocks_constant
[
i
]
# Check K and V separately
assert
torch
.
equal
(
actual_kv
[
0
],
expected
)
assert
torch
.
equal
(
actual_kv
[
1
],
expected
)
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
actual_conv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
actual_ssm
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
expected_conv
=
conv_blocks_constant
[
i
]
expected_ssm
=
ssm_blocks_constant
[
i
]
assert
torch
.
equal
(
actual_conv
,
expected_conv
)
assert
torch
.
equal
(
actual_ssm
,
expected_ssm
)
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
actual_conv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
actual_ssm
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
expected_conv
=
conv_blocks_constant
[
i
]
expected_ssm
=
ssm_blocks_constant
[
i
]
assert
torch
.
equal
(
actual_conv
,
expected_conv
)
assert
torch
.
equal
(
actual_ssm
,
expected_ssm
)
def
test_hybrid_block_table_initialization
():
...
...
tests/v1/worker/test_gpu_profiler.py
View file @
8d75f22e
...
...
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
vllm.envs
as
envs
from
vllm.profiler.
gpu_profil
er
import
WorkerProfiler
from
vllm.config
import
ProfilerConfig
from
vllm.profiler.
wrapp
er
import
WorkerProfiler
class
ConcreteWorkerProfiler
(
WorkerProfiler
):
...
...
@@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
A basic implementation of a worker profiler for testing purposes.
"""
def
__init__
(
self
):
def
__init__
(
self
,
profiler_config
:
ProfilerConfig
):
self
.
start_call_count
=
0
self
.
stop_call_count
=
0
self
.
should_fail_start
=
False
super
().
__init__
()
super
().
__init__
(
profiler_config
)
def
_start
(
self
)
->
None
:
if
self
.
should_fail_start
:
...
...
@@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
self
.
stop_call_count
+=
1
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_mocks
():
"""Fixture to reset mocks and env variables before each test."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
0
envs
.
VLLM_PROFILER_MAX_ITERS
=
0
@
pytest
.
fixture
def
default_profiler_config
():
return
ProfilerConfig
(
profiler
=
"torch"
,
torch_profiler_dir
=
"/tmp/mock"
,
delay_iterations
=
0
,
max_iterations
=
0
,
)
def
test_immediate_start_stop
():
def
test_immediate_start_stop
(
default_profiler_config
):
"""Test standard start without delay."""
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
start
()
assert
profiler
.
_running
is
True
assert
profiler
.
_active
is
True
...
...
@@ -48,10 +50,10 @@ def test_immediate_start_stop():
assert
profiler
.
stop_call_count
==
1
def
test_delayed_start
():
def
test_delayed_start
(
default_profiler_config
):
"""Test that profiler waits for N steps before actually starting."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
2
profiler
=
ConcreteWorkerProfiler
()
default_profiler_config
.
delay_iterations
=
2
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
# User requests start
profiler
.
start
()
...
...
@@ -71,10 +73,10 @@ def test_delayed_start():
assert
profiler
.
start_call_count
==
1
def
test_max_iterations
():
def
test_max_iterations
(
default_profiler_config
):
"""Test that profiler stops automatically after max iterations."""
envs
.
VLLM_PROFILER_MAX_ITERS
=
2
profiler
=
ConcreteWorkerProfiler
()
default_profiler_config
.
max_iterations
=
2
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
start
()
assert
profiler
.
_running
is
True
...
...
@@ -95,12 +97,11 @@ def test_max_iterations():
assert
profiler
.
stop_call_count
==
1
def
test_delayed_start_and_max_iters
():
def
test_delayed_start_and_max_iters
(
default_profiler_config
):
"""Test combined delayed start and max iterations."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
2
envs
.
VLLM_PROFILER_MAX_ITERS
=
2
profiler
=
ConcreteWorkerProfiler
()
default_profiler_config
.
delay_iterations
=
2
default_profiler_config
.
max_iterations
=
2
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
start
()
# Step 1
...
...
@@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
assert
profiler
.
stop_call_count
==
1
def
test_idempotency
():
def
test_idempotency
(
default_profiler_config
):
"""Test that calling start/stop multiple times doesn't break logic."""
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
# Double Start
profiler
.
start
()
...
...
@@ -142,10 +143,10 @@ def test_idempotency():
assert
profiler
.
stop_call_count
==
1
# Should only stop once
def
test_step_inactive
():
def
test_step_inactive
(
default_profiler_config
):
"""Test that stepping while inactive does nothing."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
2
profiler
=
ConcreteWorkerProfiler
()
default_profiler_config
.
delay_iterations
=
2
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
# Not started yet
profiler
.
step
()
...
...
@@ -155,9 +156,9 @@ def test_step_inactive():
assert
profiler
.
start_call_count
==
0
def
test_start_failure
():
def
test_start_failure
(
default_profiler_config
):
"""Test behavior when the underlying _start method raises exception."""
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
should_fail_start
=
True
profiler
.
start
()
...
...
@@ -168,9 +169,9 @@ def test_start_failure():
assert
profiler
.
start_call_count
==
0
# Logic failed inside start
def
test_shutdown
():
def
test_shutdown
(
default_profiler_config
):
"""Test that shutdown calls stop only if running."""
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
# Case 1: Not running
profiler
.
shutdown
()
...
...
@@ -182,10 +183,10 @@ def test_shutdown():
assert
profiler
.
stop_call_count
==
1
def
test_mixed_delay_and_stop
():
def
test_mixed_delay_and_stop
(
default_profiler_config
):
"""Test manual stop during the delay period."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
5
profiler
=
ConcreteWorkerProfiler
()
default_profiler_config
.
delay_iterations
=
5
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
start
()
profiler
.
step
()
...
...
tools/ep_kernels/install_python_libraries.sh
View file @
8d75f22e
...
...
@@ -10,9 +10,10 @@ set -ex
CUDA_HOME
=
${
CUDA_HOME
:-
/usr/local/cuda
}
PPLX_COMMIT_HASH
=
${
PPLX_COMMIT_HASH
:-
"12cecfd"
}
DEEPEP_COMMIT_HASH
=
${
DEEPEP_COMMIT_HASH
:-
"73b6ea4"
}
NVSHMEM_VER
=
3.3.
9
NVSHMEM_VER
=
3.3.
24
# Suppports both CUDA 12 and 13
WORKSPACE
=
${
WORKSPACE
:-
$(
pwd
)
/ep_kernels_workspace
}
MODE
=
${
MODE
:-
install
}
CUDA_VERSION_MAJOR
=
$(
${
CUDA_HOME
}
/bin/nvcc
--version
| egrep
-o
"release [0-9]+"
|
cut
-d
' '
-f
2
)
# Parse arguments
while
[[
$#
-gt
0
]]
;
do
...
...
@@ -75,11 +76,9 @@ ARCH=$(uname -m)
case
"
${
ARCH
,,
}
"
in
x86_64|amd64
)
NVSHMEM_SUBDIR
=
"linux-x86_64"
NVSHMEM_FILE
=
"libnvshmem-linux-x86_64-
${
NVSHMEM_VER
}
_cuda12-archive.tar.xz"
;;
aarch64|arm64
)
NVSHMEM_SUBDIR
=
"linux-sbsa"
NVSHMEM_FILE
=
"libnvshmem-linux-sbsa-
${
NVSHMEM_VER
}
_cuda12-archive.tar.xz"
;;
*
)
echo
"Unsupported architecture:
${
ARCH
}
"
>
&2
...
...
@@ -87,6 +86,7 @@ case "${ARCH,,}" in
;;
esac
NVSHMEM_FILE
=
"libnvshmem-
${
NVSHMEM_SUBDIR
}
-
${
NVSHMEM_VER
}
_cuda
${
CUDA_VERSION_MAJOR
}
-archive.tar.xz"
NVSHMEM_URL
=
"https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/
${
NVSHMEM_SUBDIR
}
/
${
NVSHMEM_FILE
}
"
pushd
"
$WORKSPACE
"
...
...
@@ -142,13 +142,6 @@ clone_repo() {
fi
}
deepep_cuda13_patch
()
{
cuda_version_major
=
$(
${
CUDA_HOME
}
/bin/nvcc
--version
| egrep
-o
"release [0-9]+"
|
cut
-d
' '
-f
2
)
if
[
${
cuda_version_major
}
-ge
13
]
;
then
sed
-i
"s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '
${
CUDA_HOME
}
/include/cccl']|"
"setup.py"
fi
}
do_build
()
{
local
repo
=
$1
local
name
=
$2
...
...
@@ -160,8 +153,9 @@ do_build() {
clone_repo
"
$repo
"
"
$name
"
"
$key
"
"
$commit
"
cd
"
$name
"
if
[
"
$name
"
==
"DeepEP"
]
;
then
deepep_cuda13_patch
# DeepEP CUDA 13 patch
if
[[
"
$name
"
==
"DeepEP"
&&
"
${
CUDA_VERSION_MAJOR
}
"
-ge
13
]]
;
then
sed
-i
"s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '
${
CUDA_HOME
}
/include/cccl']|"
"setup.py"
fi
if
[
"
$MODE
"
=
"install"
]
;
then
...
...
use_existing_torch.py
View file @
8d75f22e
...
...
@@ -3,9 +3,7 @@
import
glob
requires_files
=
glob
.
glob
(
"requirements/*.txt"
)
requires_files
+=
[
"pyproject.toml"
]
for
file
in
requires_files
:
for
file
in
(
*
glob
.
glob
(
"requirements/*.txt"
),
"pyproject.toml"
):
print
(
f
">>> cleaning
{
file
}
"
)
with
open
(
file
)
as
f
:
lines
=
f
.
readlines
()
...
...
@@ -17,5 +15,4 @@ for file in requires_files:
f
.
write
(
line
)
else
:
print
(
line
.
strip
())
print
(
f
"<<< done cleaning
{
file
}
"
)
print
()
print
(
f
"<<< done cleaning
{
file
}
\n
"
)
vllm/_aiter_ops.py
View file @
8d75f22e
...
...
@@ -9,6 +9,8 @@ import vllm.envs as envs
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
_FP8_DTYPE
=
current_platform
.
fp8_dtype
()
def
is_aiter_found
()
->
bool
:
from
importlib.util
import
find_spec
...
...
@@ -22,6 +24,15 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND
=
is_aiter_found
()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if
IS_AITER_FOUND
:
from
aiter
import
dtypes
AITER_FP8_DTYPE
=
dtypes
.
fp8
def
if_aiter_supported
(
func
:
Callable
)
->
Callable
:
"""Decorator that only executes the function if
...
...
@@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return
wrapper
def
_rocm_aiter_group_fp8_quant_impl
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
"Input shape must be divisible by group size"
from
aiter
import
QuantType
,
dtypes
,
get_hip_quant
aiter_per1x128_quant
=
get_hip_quant
(
QuantType
.
per_1x128
)
return
aiter_per1x128_quant
(
x
.
contiguous
(),
quant_dtype
=
dtypes
.
fp8
)
def
_rocm_aiter_group_fp8_quant_fake
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter
import
dtypes
M
,
N
=
x
.
shape
x_fp8
=
torch
.
empty
((
M
,
N
),
dtype
=
dtypes
.
fp8
,
device
=
x
.
device
)
out_bs
=
torch
.
empty
(
(
M
,
(
N
+
group_size
-
1
)
//
group_size
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
return
x_fp8
,
out_bs
def
_rocm_aiter_fused_moe_impl
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
...
@@ -283,6 +264,28 @@ def _rocm_aiter_grouped_topk_fake(
pass
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8
:
bool
|
None
=
None
def
_check_aiter_mla_fp8_support
()
->
bool
:
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
global
_AITER_MLA_SUPPORTS_FP8
if
_AITER_MLA_SUPPORTS_FP8
is
None
:
try
:
import
inspect
from
aiter.mla
import
mla_decode_fwd
sig
=
inspect
.
signature
(
mla_decode_fwd
)
_AITER_MLA_SUPPORTS_FP8
=
(
"q_scale"
in
sig
.
parameters
and
"kv_scale"
in
sig
.
parameters
)
except
Exception
:
_AITER_MLA_SUPPORTS_FP8
=
False
return
_AITER_MLA_SUPPORTS_FP8
def
_rocm_aiter_mla_decode_fwd_impl
(
q
:
torch
.
Tensor
,
kv_buffer
:
torch
.
Tensor
,
...
...
@@ -299,6 +302,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
)
->
None
:
from
aiter.mla
import
mla_decode_fwd
kwargs
=
{
"sm_scale"
:
sm_scale
,
"logit_cap"
:
logit_cap
,
}
# Only pass q_scale and kv_scale if the aiter library supports them
if
_check_aiter_mla_fp8_support
():
kwargs
[
"q_scale"
]
=
q_scale
kwargs
[
"kv_scale"
]
=
kv_scale
mla_decode_fwd
(
q
,
kv_buffer
.
view
(
-
1
,
1
,
1
,
q
.
shape
[
-
1
]),
...
...
@@ -308,10 +321,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_indices
,
kv_last_page_lens
,
max_seqlen_qo
,
sm_scale
=
sm_scale
,
logit_cap
=
logit_cap
,
q_scale
=
q_scale
,
kv_scale
=
kv_scale
,
**
kwargs
,
)
...
...
@@ -438,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
return
torch
.
empty_like
(
x
),
torch
.
empty_like
(
residual
)
def
_rocm_aiter_per_tensor_quant_impl
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.quant
import
per_tensor_quant_hip
return
per_tensor_quant_hip
(
x
,
scale
,
quant_dtype
)
def
_rocm_aiter_per_tensor_quant_fake
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
empty_like
(
x
,
dtype
=
quant_dtype
),
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
x
.
device
)
def
_rocm_aiter_per_token_quant_impl
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.quant
import
dynamic_per_token_scaled_quant
assert
quant_dtype
in
[
torch
.
int8
,
_FP8_DTYPE
]
out_shape
=
x
.
shape
out
=
torch
.
empty
(
x
.
shape
,
dtype
=
_FP8_DTYPE
,
device
=
x
.
device
)
if
scale
is
None
:
scale
=
torch
.
empty
((
*
out_shape
[:
-
1
],
1
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
dynamic_per_token_scaled_quant
(
out
,
x
,
scale
,
scale_ub
=
None
,
shuffle_scale
=
False
,
num_rows
=
None
,
num_rows_factor
=
1
,
)
return
out
,
scale
def
_rocm_aiter_per_token_quant_fake
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
out_shape
=
x
.
shape
return
(
torch
.
empty
(
x
.
shape
,
dtype
=
_FP8_DTYPE
,
device
=
x
.
device
),
torch
.
empty
((
*
out_shape
[:
-
1
],
1
),
dtype
=
torch
.
float32
,
device
=
x
.
device
),
)
def
_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.triton.fused_fp8_quant
import
fused_rms_fp8_group_quant
(
x_quant
,
x_quant_scales
),
_
,
_
,
res
=
fused_rms_fp8_group_quant
(
x
,
weight
,
variance_epsilon
,
None
,
None
,
None
,
group_size
=
group_size
,
dtype_quant
=
AITER_FP8_DTYPE
,
res1
=
residual
,
)
return
(
x_quant
,
x_quant_scales
,
res
)
def
_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
scale_shape
=
(
M
,
(
N
+
group_size
-
1
)
//
group_size
)
return
(
torch
.
empty_like
(
x
,
dtype
=
AITER_FP8_DTYPE
,
device
=
x
.
device
),
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
float32
,
device
=
x
.
device
),
torch
.
empty_like
(
residual
,
device
=
residual
.
device
),
)
def
_rocm_aiter_rmsnorm_fp8_group_quant_impl
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.triton.fused_fp8_quant
import
fused_rms_fp8_group_quant
(
x_quant
,
x_quant_scales
),
_
,
_
,
res
=
fused_rms_fp8_group_quant
(
x
,
weight
,
variance_epsilon
,
None
,
None
,
None
,
group_size
=
group_size
,
dtype_quant
=
AITER_FP8_DTYPE
,
res1
=
None
,
)
return
(
x_quant
,
x_quant_scales
)
def
_rocm_aiter_rmsnorm_fp8_group_quant_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
scale_shape
=
(
M
,
(
N
+
group_size
-
1
)
//
group_size
)
return
(
torch
.
empty_like
(
x
,
dtype
=
AITER_FP8_DTYPE
,
device
=
x
.
device
),
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
float32
,
device
=
x
.
device
),
)
def
_rocm_aiter_group_fp8_quant_impl
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
"Input shape must be divisible by group size"
from
aiter
import
QuantType
,
get_hip_quant
aiter_per1x128_quant
=
get_hip_quant
(
QuantType
.
per_1x128
)
return
aiter_per1x128_quant
(
x
.
contiguous
(),
quant_dtype
=
AITER_FP8_DTYPE
)
def
_rocm_aiter_group_fp8_quant_fake
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
x_fp8
=
torch
.
empty
((
M
,
N
),
dtype
=
AITER_FP8_DTYPE
,
device
=
x
.
device
)
out_bs
=
torch
.
empty
(
(
M
,
(
N
+
group_size
-
1
)
//
group_size
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
return
x_fp8
,
out_bs
def
_rocm_aiter_act_mul_and_fp8_group_quant_impl
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.triton.activation
import
act_mul_and_fp8_group_quant
return
act_mul_and_fp8_group_quant
(
x
,
activation
=
"silu"
,
group_size
=
group_size
,
dtype_quant
=
AITER_FP8_DTYPE
,
)
def
_rocm_aiter_act_mul_and_fp8_group_quant_fake
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
assert
N
%
2
==
0
N_half
=
N
//
2
x_fp8
=
torch
.
empty
((
M
,
N_half
),
dtype
=
AITER_FP8_DTYPE
,
device
=
x
.
device
)
out_bs
=
torch
.
empty
(
(
M
,
(
N_half
+
group_size
-
1
)
//
group_size
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
return
x_fp8
,
out_bs
# Global flag to ensure ops are registered only once
_OPS_REGISTERED
=
False
...
...
@@ -473,7 +672,7 @@ class rocm_aiter_ops:
@
if_aiter_supported
def
is_linear_fp8_enaled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
return
cls
.
is_linear_enabled
()
and
current_platform
.
is_fp8_fnuz
()
return
cls
.
is_linear_enabled
()
@
classmethod
@
if_aiter_supported
...
...
@@ -548,14 +747,6 @@ class rocm_aiter_ops:
)
# register all the custom ops here
direct_register_custom_op
(
op_name
=
"rocm_aiter_group_fp8_quant"
,
op_func
=
_rocm_aiter_group_fp8_quant_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_group_fp8_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_asm_moe_tkw1"
,
op_func
=
_rocm_aiter_asm_moe_tkw1_impl
,
...
...
@@ -615,27 +806,62 @@ class rocm_aiter_ops:
direct_register_custom_op
(
op_name
=
"rocm_aiter_gemm_a8w8_blockscale"
,
op_func
=
_rocm_aiter_gemm_a8w8_blockscale_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_gemm_a8w8_blockscale_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_rms_norm"
,
op_func
=
_rocm_aiter_rms_norm_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_rms_norm_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_rmsnorm2d_fwd_with_add"
,
op_func
=
_rocm_aiter_rmsnorm2d_fwd_with_add_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_rmsnorm2d_fwd_with_add_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_rmsnorm_fp8_group_quant"
,
op_func
=
_rocm_aiter_rmsnorm_fp8_group_quant_impl
,
fake_impl
=
_rocm_aiter_rmsnorm_fp8_group_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_rmsnorm_with_add_fp8_group_quant"
,
op_func
=
_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl
,
fake_impl
=
_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_act_mul_and_fp8_group_quant"
,
op_func
=
_rocm_aiter_act_mul_and_fp8_group_quant_impl
,
fake_impl
=
_rocm_aiter_act_mul_and_fp8_group_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_group_fp8_quant"
,
op_func
=
_rocm_aiter_group_fp8_quant_impl
,
fake_impl
=
_rocm_aiter_group_fp8_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_per_tensor_quant"
,
op_func
=
_rocm_aiter_per_tensor_quant_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_per_tensor_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_per_token_quant"
,
op_func
=
_rocm_aiter_per_token_quant_impl
,
mutates_args
=
[
"scale"
],
fake_impl
=
_rocm_aiter_per_token_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
_OPS_REGISTERED
=
True
@
staticmethod
...
...
@@ -830,6 +1056,22 @@ class rocm_aiter_ops:
kv_scale
=
kv_scale
,
)
@
staticmethod
def
per_tensor_quant
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
vllm
.
rocm_aiter_per_tensor_quant
(
x
,
quant_dtype
,
scale
)
@
staticmethod
def
per_token_quant
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
vllm
.
rocm_aiter_per_token_quant
(
x
,
quant_dtype
,
scale
)
@
staticmethod
def
triton_fp4_gemm_dynamic_qaunt
(
x
:
torch
.
Tensor
,
...
...
vllm/_custom_ops.py
View file @
8d75f22e
...
...
@@ -441,6 +441,46 @@ def rms_norm_dynamic_per_token_quant(
return
output
,
scales
# fused quant layer norm ops blocked
def
rms_norm_per_block_quant
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
group_size
:
list
[
int
],
scale_ub
:
torch
.
Tensor
|
None
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
is_scale_transposed
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
len
(
group_size
)
==
2
output
=
torch
.
empty_like
(
input
,
dtype
=
quant_dtype
)
if
is_scale_transposed
:
scales
=
torch
.
empty
(
(
input
.
shape
[
-
1
]
//
group_size
[
1
],
input
.
numel
()
//
input
.
shape
[
-
1
]),
device
=
input
.
device
,
dtype
=
torch
.
float32
,
).
transpose
(
0
,
1
)
else
:
scales
=
torch
.
empty
(
(
input
.
numel
()
//
input
.
shape
[
-
1
],
input
.
shape
[
-
1
]
//
group_size
[
1
]),
device
=
input
.
device
,
dtype
=
torch
.
float32
,
)
torch
.
ops
.
_C
.
rms_norm_per_block_quant
(
output
,
input
,
weight
,
scales
,
epsilon
,
scale_ub
,
residual
,
group_size
[
1
],
is_scale_transposed
,
)
return
output
,
scales
# quantization ops
# awq
def
awq_dequantize
(
...
...
@@ -660,6 +700,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def
cutlass_encode_and_reorder_int4b_fake
(
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b
,
memory_format
=
torch
.
contiguous_format
)
@
register_fake
(
"_C::cutlass_encode_and_reorder_int4b_grouped"
)
def
cutlass_encode_and_reorder_int4b_grouped_fake
(
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b
,
memory_format
=
torch
.
contiguous_format
)
if
hasattr
(
torch
.
ops
.
_C
,
"allspark_w8a16_gemm"
):
...
...
@@ -1023,6 +1067,7 @@ def get_cutlass_moe_mm_problem_sizes(
n
:
int
,
k
:
int
,
blockscale_offsets
:
torch
.
Tensor
|
None
=
None
,
force_swap_ab
:
bool
|
None
=
None
,
):
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
...
...
@@ -1032,9 +1077,20 @@ def get_cutlass_moe_mm_problem_sizes(
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
used in the fused MoE operation.
Optional:
- force_swap_ab: If set to True or False, explicitly enable or disable the
A/B input swap optimization. If None (default), the swap
is selected automatically based on tensor sizes.
"""
return
torch
.
ops
.
_C
.
get_cutlass_moe_mm_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
,
force_swap_ab
,
)
...
...
@@ -1422,6 +1478,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
return
torch
.
ops
.
_C
.
cutlass_encode_and_reorder_int4b
(
b
)
def
cutlass_w4a8_moe_mm
(
out_tensors
:
torch
.
Tensor
,
a_tensors
:
torch
.
Tensor
,
b_tensors
:
torch
.
Tensor
,
a_scales
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_group_scales
:
torch
.
Tensor
,
b_group_size
:
int
,
expert_offsets
:
torch
.
Tensor
,
problem_sizes
:
torch
.
Tensor
,
a_strides
:
torch
.
Tensor
,
b_strides
:
torch
.
Tensor
,
c_strides
:
torch
.
Tensor
,
group_scale_strides
:
torch
.
Tensor
,
maybe_schedule
:
str
|
None
=
None
,
):
"""
Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
and both per-channel + per-token scaling in the epilogue.
Args:
out_tensors:
Output buffer for all experts (updated in-place).
a_tensors:
FP8 (E4M3FN) activations for all experts.
b_tensors:
INT4-packed weight matrix for all experts, packed to INT32
a_scales:
Per-token FP8 activation scales, applied in the epilogue.
b_scales:
Per-channel FP8 weight scales for each expert, applied in the epilogue.
b_group_scales:
FP8 scale values for group-wise INT4 weight blocks.
b_group_size:
Number of elements grouped under each entry of b_group_scales.
expert_offsets:
Cumulative token offsets
problem_sizes:
Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
a/b/c/group_scale_strides:
Strides describing the memory layout of the input tensors.
maybe_schedule:
Optional override to choose a specific kernel or epilogue schedule.
Returns:
out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result.
"""
return
torch
.
ops
.
_C
.
cutlass_w4a8_moe_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
,
maybe_schedule
,
)
def
cutlass_encode_and_reorder_int4b_grouped
(
b_tensors
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
_C
.
cutlass_encode_and_reorder_int4b_grouped
(
b_tensors
)
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
@
register_fake
(
"_C::permute_cols"
)
...
...
@@ -1603,7 +1731,7 @@ def scaled_fp8_quant(
output
,
input
,
scale
,
scale_ub
)
else
:
scale
=
torch
.
empty
(
(
1
,
1
)
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
scale
=
torch
.
empty
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
...
...
@@ -1882,6 +2010,7 @@ def moe_align_block_size(
sorted_token_ids
:
torch
.
Tensor
,
experts_ids
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
torch
.
ops
.
_moe_C
.
moe_align_block_size
(
topk_ids
,
...
...
@@ -1890,6 +2019,7 @@ def moe_align_block_size(
sorted_token_ids
,
experts_ids
,
num_tokens_post_pad
,
expert_map
,
)
...
...
@@ -1924,6 +2054,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
torch
.
ops
.
_moe_C
.
moe_lora_align_block_size
(
topk_ids
,
...
...
@@ -1938,6 +2069,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad
,
adapter_enabled
,
lora_ids
,
expert_map
,
)
...
...
vllm/attention/backends/abstract.py
View file @
8d75f22e
...
...
@@ -166,6 +166,10 @@ class AttentionBackend(ABC):
def
supports_sink
(
cls
)
->
bool
:
return
False
@
classmethod
def
supports_mm_prefix
(
cls
)
->
bool
:
return
False
@
classmethod
def
is_sparse
(
cls
)
->
bool
:
return
False
...
...
@@ -207,6 +211,7 @@ class AttentionBackend(ABC):
use_mla
:
bool
,
has_sink
:
bool
,
use_sparse
:
bool
,
use_mm_prefix
:
bool
,
device_capability
:
"DeviceCapability"
,
attn_type
:
str
,
)
->
list
[
str
]:
...
...
@@ -219,6 +224,10 @@ class AttentionBackend(ABC):
invalid_reasons
.
append
(
"kv_cache_dtype not supported"
)
if
not
cls
.
supports_block_size
(
block_size
):
invalid_reasons
.
append
(
"block_size not supported"
)
if
use_mm_prefix
and
not
cls
.
supports_mm_prefix
():
invalid_reasons
.
append
(
"partial multimodal token full attention not supported"
)
if
use_mla
!=
cls
.
is_mla
():
if
use_mla
:
invalid_reasons
.
append
(
"MLA not supported"
)
...
...
@@ -289,6 +298,16 @@ class AttentionImpl(ABC, Generic[T]):
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode
:
bool
=
False
# Whether this attention implementation supports pre-quantized query input.
# When True, the attention layer will quantize queries before passing them
# to this backend, allowing torch.compile to fuse the quantization with
# previous operations. This is typically supported when using FP8 KV cache
# with compatible attention kernels (e.g., TRT-LLM).
# Subclasses should set this in __init__.
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input
:
bool
=
False
dcp_world_size
:
int
dcp_rank
:
int
...
...
@@ -368,22 +387,6 @@ class AttentionImpl(ABC, Generic[T]):
"""
return
False
def
supports_quant_query_input
(
self
)
->
bool
:
"""
Check if this attention implementation supports pre-quantized query input.
When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO add support to more backends:
https://github.com/vllm-project/vllm/issues/25584
Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return
False
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
pass
...
...
vllm/attention/layer.py
View file @
8d75f22e
...
...
@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.batch_invariant
import
vllm_is_batch_invariant
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
UnquantizedLinearMethod
,
...
...
@@ -88,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
if
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
else
:
from
vllm.attention.utils.fa_utils
import
flash_attn_varlen_func
try
:
from
vllm.attention.utils.fa_utils
import
flash_attn_varlen_func
except
ImportError
:
flash_attn_varlen_func
=
None
else
:
flash_attn_varlen_func
=
None
...
...
@@ -230,6 +234,10 @@ class Attention(nn.Module, AttentionLayerBase):
self
.
sliding_window
=
sliding_window
self
.
has_sink
=
extra_impl_args
.
get
(
"sinks"
)
is
not
None
# NOTE: model_config may be None during certain tests
model_config
=
vllm_config
.
model_config
self
.
use_mm_prefix
=
model_config
is
not
None
and
model_config
.
is_mm_prefix_lm
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype
=
torch
.
get_default_dtype
()
...
...
@@ -241,11 +249,30 @@ class Attention(nn.Module, AttentionLayerBase):
block_size
,
use_mla
=
False
,
has_sink
=
self
.
has_sink
,
use_mm_prefix
=
self
.
use_mm_prefix
,
attn_type
=
attn_type
,
)
else
:
self
.
attn_backend
=
attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
and
vllm_is_batch_invariant
()
and
(
self
.
attn_backend
.
get_name
()
==
"FLASHINFER"
or
self
.
attn_backend
.
get_name
()
==
"TRITON_MLA"
)
):
logger
.
warning_once
(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported."
,
scope
=
"local"
,
)
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
self
.
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
...
...
@@ -303,7 +330,7 @@ class Attention(nn.Module, AttentionLayerBase):
self
.
query_quant
=
None
if
(
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
self
.
impl
.
supports_quant_query_input
()
and
self
.
impl
.
supports_quant_query_input
):
self
.
query_quant
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
)
...
...
@@ -338,7 +365,7 @@ class Attention(nn.Module, AttentionLayerBase):
assert
self
.
kv_cache_dtype
in
{
"fp8"
,
"fp8_e4m3"
}
# check if query quantization is supported
if
self
.
impl
.
supports_quant_query_input
()
:
if
self
.
impl
.
supports_quant_query_input
:
query
,
_
=
self
.
query_quant
(
query
,
self
.
_q_scale
)
if
self
.
use_output
:
...
...
@@ -623,6 +650,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla
=
True
,
use_sparse
=
use_sparse
,
)
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
and
vllm_is_batch_invariant
()
and
(
self
.
attn_backend
.
get_name
()
==
"TRITON_MLA"
or
self
.
attn_backend
.
get_name
()
==
"FLASHINFER"
)
):
logger
.
warning_once
(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported."
,
scope
=
"local"
,
)
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
cast
(
type
[
MLAAttentionImpl
],
self
.
attn_backend
.
get_impl_cls
())
self
.
impl
=
impl_cls
(
num_heads
=
self
.
num_heads
,
...
...
vllm/attention/layers/cross_attention.py
View file @
8d75f22e
...
...
@@ -103,7 +103,7 @@ def create_cross_attention_backend(
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
new_metadata
.
seq_lens
=
common_attn_metadata
.
encoder_seq_lens
new_metadata
.
seq_lens_cpu
=
torch
.
from_numpy
(
new_metadata
.
_
seq_lens_cpu
=
torch
.
from_numpy
(
common_attn_metadata
.
encoder_seq_lens_cpu
)
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
…
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment