Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
656
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
835 additions
and
237 deletions
+835
-237
tests/v1/logits_processors/test_custom_offline.py
tests/v1/logits_processors/test_custom_offline.py
+0
-5
tests/v1/logits_processors/test_custom_online.py
tests/v1/logits_processors/test_custom_online.py
+2
-10
tests/v1/logits_processors/utils.py
tests/v1/logits_processors/utils.py
+1
-1
tests/v1/metrics/test_stats.py
tests/v1/metrics/test_stats.py
+102
-1
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+6
-2
tests/v1/spec_decode/test_max_len.py
tests/v1/spec_decode/test_max_len.py
+2
-2
tests/v1/spec_decode/test_speculators_eagle3.py
tests/v1/spec_decode/test_speculators_eagle3.py
+5
-0
tests/v1/spec_decode/test_tree_attention.py
tests/v1/spec_decode/test_tree_attention.py
+2
-2
tests/v1/structured_output/test_backend_guidance.py
tests/v1/structured_output/test_backend_guidance.py
+74
-0
tests/v1/structured_output/test_reasoning_structured_output.py
.../v1/structured_output/test_reasoning_structured_output.py
+1
-0
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+15
-6
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+90
-86
tests/v1/worker/test_gpu_profiler.py
tests/v1/worker/test_gpu_profiler.py
+36
-35
tools/ep_kernels/install_python_libraries.sh
tools/ep_kernels/install_python_libraries.sh
+6
-12
use_existing_torch.py
use_existing_torch.py
+2
-5
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+290
-48
vllm/_custom_ops.py
vllm/_custom_ops.py
+134
-2
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+19
-16
vllm/attention/layer.py
vllm/attention/layer.py
+47
-3
vllm/attention/layers/cross_attention.py
vllm/attention/layers/cross_attention.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
656 of 656+
files are displayed.
Plain diff
Email patch
tests/v1/logits_processors/test_custom_offline.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
random
import
sys
from
typing
import
Any
from
typing
import
Any
import
pytest
import
pytest
...
@@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
...
@@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
from
tests.v1.logits_processors.utils
import
(
from
tests.v1.logits_processors.utils
import
(
DUMMY_LOGITPROC_ARG
,
DUMMY_LOGITPROC_ARG
,
DUMMY_LOGITPROC_FQCN
,
DUMMY_LOGITPROC_FQCN
,
DUMMY_LOGITPROC_MODULE
,
MAX_TOKENS
,
MAX_TOKENS
,
MODEL_NAME
,
MODEL_NAME
,
POOLING_MODEL_NAME
,
POOLING_MODEL_NAME
,
...
@@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
...
@@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
CustomLogitprocSource
,
CustomLogitprocSource
,
DummyLogitsProcessor
,
DummyLogitsProcessor
,
WrappedPerReqLogitsProcessor
,
WrappedPerReqLogitsProcessor
,
dummy_module
,
prompts
,
prompts
,
)
)
from
tests.v1.logits_processors.utils
import
entry_points
as
fake_entry_points
from
tests.v1.logits_processors.utils
import
entry_points
as
fake_entry_points
...
@@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
...
@@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
kwargs
:
dict
[
str
,
list
[
str
|
type
[
LogitsProcessor
]]]
=
{}
kwargs
:
dict
[
str
,
list
[
str
|
type
[
LogitsProcessor
]]]
=
{}
if
logitproc_source
==
CustomLogitprocSource
.
LOGITPROC_SOURCE_FQCN
:
if
logitproc_source
==
CustomLogitprocSource
.
LOGITPROC_SOURCE_FQCN
:
# Scenario: load logitproc based on fully-qualified class name (FQCN)
# Scenario: load logitproc based on fully-qualified class name (FQCN)
# Inject dummy module which defines logitproc
sys
.
modules
[
DUMMY_LOGITPROC_MODULE
]
=
dummy_module
kwargs
[
"logits_processors"
]
=
[
DUMMY_LOGITPROC_FQCN
]
kwargs
[
"logits_processors"
]
=
[
DUMMY_LOGITPROC_FQCN
]
elif
logitproc_source
==
CustomLogitprocSource
.
LOGITPROC_SOURCE_CLASS
:
elif
logitproc_source
==
CustomLogitprocSource
.
LOGITPROC_SOURCE_CLASS
:
# Scenario: load logitproc from provided class object
# Scenario: load logitproc from provided class object
...
...
tests/v1/logits_processors/test_custom_online.py
View file @
8d75f22e
...
@@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
...
@@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
from
tests.v1.logits_processors.utils
import
(
from
tests.v1.logits_processors.utils
import
(
DUMMY_LOGITPROC_ARG
,
DUMMY_LOGITPROC_ARG
,
DUMMY_LOGITPROC_FQCN
,
DUMMY_LOGITPROC_FQCN
,
DUMMY_LOGITPROC_MODULE
,
MAX_TOKENS
,
MAX_TOKENS
,
MODEL_NAME
,
MODEL_NAME
,
TEMP_GREEDY
,
TEMP_GREEDY
,
dummy_module
,
prompts
,
prompts
,
)
)
from
tests.v1.logits_processors.utils
import
entry_points
as
fake_entry_points
from
tests.v1.logits_processors.utils
import
entry_points
as
fake_entry_points
...
@@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
...
@@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
main
.
main
()
main
.
main
()
def
_server_with_logitproc_
module
(
def
_server_with_logitproc_
fqcn
(
env_dict
:
dict
[
str
,
str
]
|
None
,
env_dict
:
dict
[
str
,
str
]
|
None
,
model
:
str
,
model
:
str
,
vllm_serve_args
:
list
[
str
],
vllm_serve_args
:
list
[
str
],
)
->
None
:
)
->
None
:
"""Start vLLM server, inject module with dummy logitproc"""
"""Start vLLM server, inject module with dummy logitproc"""
# Patch `modules` to inject dummy logitproc module
from
vllm.entrypoints.cli
import
main
from
vllm.entrypoints.cli
import
main
sys
.
modules
[
DUMMY_LOGITPROC_MODULE
]
=
dummy_module
# fork is required for workers to see entrypoint patch
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"fork"
if
env_dict
is
not
None
:
if
env_dict
is
not
None
:
os
.
environ
.
update
(
env_dict
)
os
.
environ
.
update
(
env_dict
)
...
@@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
...
@@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
if
request
.
param
:
if
request
.
param
:
# Launch server, append FQCN argument, inject dummy logitproc module
# Launch server, append FQCN argument, inject dummy logitproc module
args
=
default_server_args
+
request
.
param
args
=
default_server_args
+
request
.
param
_server_fxn
=
_server_with_logitproc_
module
_server_fxn
=
_server_with_logitproc_
fqcn
else
:
else
:
# Launch server, inject dummy logitproc entrypoint
# Launch server, inject dummy logitproc entrypoint
args
=
default_server_args
args
=
default_server_args
...
...
tests/v1/logits_processors/utils.py
View file @
8d75f22e
...
@@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
...
@@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY
=
0.0
TEMP_GREEDY
=
0.0
MAX_TOKENS
=
20
MAX_TOKENS
=
20
DUMMY_LOGITPROC_ENTRYPOINT
=
"dummy_logitproc"
DUMMY_LOGITPROC_ENTRYPOINT
=
"dummy_logitproc"
DUMMY_LOGITPROC_MODULE
=
"
DummyModule
"
DUMMY_LOGITPROC_MODULE
=
"
tests.v1.logits_processors.utils
"
DUMMY_LOGITPROC_FQCN
=
f
"
{
DUMMY_LOGITPROC_MODULE
}
:DummyLogitsProcessor"
DUMMY_LOGITPROC_FQCN
=
f
"
{
DUMMY_LOGITPROC_MODULE
}
:DummyLogitsProcessor"
...
...
tests/v1/metrics/test_stats.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.metrics.stats
import
IterationStats
,
RequestStateStats
def
test_iteration_stats_repr
():
def
test_iteration_stats_repr
():
iteration_stats
=
IterationStats
()
iteration_stats
=
IterationStats
()
assert
repr
(
iteration_stats
).
startswith
(
"IterationStats("
)
assert
repr
(
iteration_stats
).
startswith
(
"IterationStats("
)
def
test_prefill_kv_computed_with_cache
():
"""Test that prefill KV compute correctly excludes cached tokens."""
iteration_stats
=
IterationStats
()
req_stats
=
RequestStateStats
(
arrival_time
=
0.0
)
req_stats
.
scheduled_ts
=
0.1
req_stats
.
first_token_ts
=
0.5
req_stats
.
last_token_ts
=
5.0
req_stats
.
num_generation_tokens
=
50
# Case 1: With prefix cache (1200 tokens cached)
iteration_stats
.
update_from_finished_request
(
finish_reason
=
FinishReason
.
STOP
,
num_prompt_tokens
=
10000
,
max_tokens_param
=
100
,
req_stats
=
req_stats
,
num_cached_tokens
=
1200
,
)
finished_req
=
iteration_stats
.
finished_requests
[
0
]
assert
finished_req
.
num_prompt_tokens
==
10000
assert
finished_req
.
num_cached_tokens
==
1200
# Verify calculation: prefill KV = prompt tokens - cached tokens
prefill_kv_computed
=
finished_req
.
num_prompt_tokens
-
max
(
finished_req
.
num_cached_tokens
,
0
)
assert
prefill_kv_computed
==
8800
# 10000 - 1200
def
test_prefill_kv_computed_no_cache
():
"""Test prefill KV compute without prefix caching."""
iteration_stats
=
IterationStats
()
req_stats
=
RequestStateStats
(
arrival_time
=
0.0
)
req_stats
.
scheduled_ts
=
0.1
req_stats
.
first_token_ts
=
0.5
req_stats
.
last_token_ts
=
2.0
req_stats
.
num_generation_tokens
=
10
# Case 2: No prefix cache
iteration_stats
.
update_from_finished_request
(
finish_reason
=
FinishReason
.
STOP
,
num_prompt_tokens
=
2000
,
max_tokens_param
=
100
,
req_stats
=
req_stats
,
num_cached_tokens
=
0
,
)
finished_req
=
iteration_stats
.
finished_requests
[
0
]
assert
finished_req
.
num_prompt_tokens
==
2000
assert
finished_req
.
num_cached_tokens
==
0
# Verify calculation: prefill KV = full prompt when no cache
prefill_kv_computed
=
finished_req
.
num_prompt_tokens
-
max
(
finished_req
.
num_cached_tokens
,
0
)
assert
prefill_kv_computed
==
2000
def
test_prefill_kv_computed_edge_cases
():
"""Test edge cases for prefill KV compute calculation."""
iteration_stats
=
IterationStats
()
req_stats
=
RequestStateStats
(
arrival_time
=
0.0
)
req_stats
.
scheduled_ts
=
0.1
req_stats
.
first_token_ts
=
0.5
req_stats
.
last_token_ts
=
1.0
req_stats
.
num_generation_tokens
=
1
# Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully)
iteration_stats
.
update_from_finished_request
(
finish_reason
=
FinishReason
.
STOP
,
num_prompt_tokens
=
100
,
max_tokens_param
=
10
,
req_stats
=
req_stats
,
num_cached_tokens
=-
1
,
)
finished_req
=
iteration_stats
.
finished_requests
[
0
]
# max() should handle negative values
prefill_kv_computed
=
finished_req
.
num_prompt_tokens
-
max
(
finished_req
.
num_cached_tokens
,
0
)
assert
prefill_kv_computed
==
100
# Should treat negative as 0
# Case 4: All tokens cached (shouldn't happen in practice)
iteration_stats2
=
IterationStats
()
iteration_stats2
.
update_from_finished_request
(
finish_reason
=
FinishReason
.
STOP
,
num_prompt_tokens
=
100
,
max_tokens_param
=
10
,
req_stats
=
req_stats
,
num_cached_tokens
=
100
,
)
finished_req2
=
iteration_stats2
.
finished_requests
[
0
]
prefill_kv_computed2
=
finished_req2
.
num_prompt_tokens
-
max
(
finished_req2
.
num_cached_tokens
,
0
)
assert
prefill_kv_computed2
==
0
# All cached, nothing computed
tests/v1/spec_decode/test_eagle.py
View file @
8d75f22e
...
@@ -339,7 +339,7 @@ def test_load_model(
...
@@ -339,7 +339,7 @@ def test_load_model(
"multi-token eagle spec decode on current platform"
"multi-token eagle spec decode on current platform"
)
)
if
attn_backend
==
"
FLASH_ATTN
"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"
ROCM_AITER_FA
"
and
current_platform
.
is_rocm
():
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
# Setup draft model mock
# Setup draft model mock
...
@@ -436,7 +436,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
...
@@ -436,7 +436,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
"because it requires special input mocking."
"because it requires special input mocking."
)
)
if
attn_backend
==
"
FLASH_ATTN
"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"
ROCM_AITER_FA
"
and
current_platform
.
is_rocm
():
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
# Use GPU device
# Use GPU device
...
@@ -543,6 +543,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
...
@@ -543,6 +543,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
attn_metadata_builder_cls
,
_
=
try_get_attention_backend
(
attn_metadata_builder_cls
,
_
=
try_get_attention_backend
(
AttentionBackendEnum
.
TREE_ATTN
AttentionBackendEnum
.
TREE_ATTN
)
)
elif
attn_backend
==
"ROCM_AITER_FA"
:
attn_metadata_builder_cls
,
_
=
try_get_attention_backend
(
AttentionBackendEnum
.
ROCM_AITER_FA
)
else
:
else
:
raise
ValueError
(
f
"Unsupported attention backend:
{
attn_backend
}
"
)
raise
ValueError
(
f
"Unsupported attention backend:
{
attn_backend
}
"
)
...
...
tests/v1/spec_decode/test_max_len.py
View file @
8d75f22e
...
@@ -47,7 +47,7 @@ def test_eagle_max_len(
...
@@ -47,7 +47,7 @@ def test_eagle_max_len(
"multi-token eagle spec decode on current platform"
"multi-token eagle spec decode on current platform"
)
)
if
attn_backend
==
"
FLASH_ATTN
"
and
current_platform
.
is_rocm
():
if
attn_backend
==
"
ROCM_AITER_FA
"
and
current_platform
.
is_rocm
():
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
llm
=
LLM
(
llm
=
LLM
(
...
@@ -82,7 +82,7 @@ def test_eagle_max_len(
...
@@ -82,7 +82,7 @@ def test_eagle_max_len(
len
(
o
.
prompt_token_ids
)
len
(
o
.
prompt_token_ids
)
<
80
<
80
<
len
(
o
.
prompt_token_ids
)
+
len
(
o
.
outputs
[
0
].
token_ids
)
<
len
(
o
.
prompt_token_ids
)
+
len
(
o
.
outputs
[
0
].
token_ids
)
<
200
<
=
200
),
(
),
(
"This test is only meaningful if the output "
"This test is only meaningful if the output "
"is longer than the eagle max length"
"is longer than the eagle max length"
...
...
tests/v1/spec_decode/test_speculators_eagle3.py
View file @
8d75f22e
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
from
vllm.config
import
SpeculativeConfig
from
vllm.config
import
SpeculativeConfig
from
vllm.model_executor.models.interfaces
import
supports_eagle3
from
vllm.model_executor.models.interfaces
import
supports_eagle3
from
vllm.platforms
import
current_platform
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
...
@@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
pytest
.
param
(
pytest
.
param
(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16"
,
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16"
,
id
=
"qwen3-eagle3-speculator-w4a16-verifier"
,
id
=
"qwen3-eagle3-speculator-w4a16-verifier"
,
marks
=
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"The tests are skipped on rocm platform."
,
),
),
),
],
],
)
)
...
...
tests/v1/spec_decode/test_tree_attention.py
View file @
8d75f22e
...
@@ -88,8 +88,8 @@ def forward_attention(
...
@@ -88,8 +88,8 @@ def forward_attention(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc
.
cpu
(),
query_start_loc_cpu
=
query_start_loc
.
cpu
(),
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_cpu
=
seq_lens
.
cpu
(),
_
seq_lens_cpu
=
seq_lens
.
cpu
(),
num_computed_tokens_cpu
=
context_lens
.
cpu
(),
_
num_computed_tokens_cpu
=
context_lens
.
cpu
(),
num_reqs
=
batch_size
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
...
...
tests/v1/structured_output/test_backend_guidance.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
from
concurrent.futures
import
Future
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
vllm.config
import
StructuredOutputsConfig
,
VllmConfig
from
vllm.config
import
StructuredOutputsConfig
,
VllmConfig
from
vllm.config.model
import
ModelConfig
from
vllm.config.model
import
ModelConfig
from
vllm.config.parallel
import
ParallelConfig
from
vllm.config.speculative
import
SpeculativeConfig
from
vllm.config.speculative
import
SpeculativeConfig
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -116,3 +121,72 @@ def test_grammar_bitmask_with_specdec():
...
@@ -116,3 +121,72 @@ def test_grammar_bitmask_with_specdec():
)
# EOS not the final token
)
# EOS not the final token
grammar_bitmask
(
request
,
prompt
[
i
:])
# EOS not present
grammar_bitmask
(
request
,
prompt
[
i
:])
# EOS not present
grammar_bitmask
(
request
,
prompt
[
i
:]
+
[
tokenizer
.
eos_token_id
])
grammar_bitmask
(
request
,
prompt
[
i
:]
+
[
tokenizer
.
eos_token_id
])
@
pytest
.
mark
.
parametrize
(
"async_grammar"
,
[
True
,
False
])
def
test_grammar_init_async_and_sync
(
async_grammar
):
"""Test grammar initialization works correctly in both async and sync modes.
This test validates that the distributed_executor_backend config option
correctly controls whether grammar compilation happens asynchronously
(via executor.submit) or synchronously. When set to "external_launcher",
grammar compilation is synchronous to avoid deadlocks.
"""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
TOKENIZER
)
prompt
=
tokenizer
.
encode
(
'{"a": "b"}'
)
# Use "external_launcher" for sync mode, None for async mode
executor_backend
=
None
if
async_grammar
else
"external_launcher"
vllm_config
=
VllmConfig
(
model_config
=
ModelConfig
(
tokenizer
=
TOKENIZER
),
structured_outputs_config
=
StructuredOutputsConfig
(
backend
=
"guidance"
),
parallel_config
=
ParallelConfig
(
distributed_executor_backend
=
executor_backend
),
)
structured_output_manager
=
StructuredOutputManager
(
vllm_config
)
sampling_params
=
SamplingParams
(
structured_outputs
=
StructuredOutputsParams
(
json
=
'{"type": "object"}'
,
),
)
sampling_params
.
structured_outputs
.
_backend
=
"guidance"
request
=
Request
(
"test_request"
,
prompt_token_ids
=
prompt
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
eos_token_id
=
tokenizer
.
eos_token_id
,
)
structured_output_manager
.
grammar_init
(
request
)
# Check the internal _grammar type immediately after init
# Before _check_grammar_completion is called, async mode should have a Future
raw_grammar
=
request
.
structured_output_request
.
_grammar
if
async_grammar
:
assert
isinstance
(
raw_grammar
,
Future
),
(
"Async mode should store a Future before completion"
)
else
:
assert
not
isinstance
(
raw_grammar
,
Future
),
(
"Sync mode should store the grammar directly, not a Future"
)
# Wait for grammar to be ready (handles both async and sync cases)
start_time
=
time
.
time
()
while
not
request
.
structured_output_request
.
_check_grammar_completion
():
if
time
.
time
()
-
start_time
>
5
:
# 5-second timeout
pytest
.
fail
(
"Grammar compilation timed out"
)
time
.
sleep
(
0.01
)
# After completion, _grammar should no longer be a Future
assert
not
isinstance
(
request
.
structured_output_request
.
_grammar
,
Future
)
# Verify grammar is properly initialized and functional
grammar
=
request
.
structured_output_request
.
grammar
assert
grammar
is
not
None
assert
not
grammar
.
is_terminated
()
# Verify the grammar can accept valid tokens
assert
grammar
.
accept_tokens
(
request
.
request_id
,
prompt
)
tests/v1/structured_output/test_reasoning_structured_output.py
View file @
8d75f22e
...
@@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
...
@@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
request
.
use_structured_output
=
True
request
.
use_structured_output
=
True
request
.
prompt_token_ids
=
[
1
,
2
,
3
,
4
,
5
]
request
.
prompt_token_ids
=
[
1
,
2
,
3
,
4
,
5
]
request
.
all_token_ids
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
]
request
.
all_token_ids
=
[
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
]
request
.
num_computed_tokens
=
5
return
request
return
request
def
test_should_fill_bitmask_with_enable_in_reasoning
(
def
test_should_fill_bitmask_with_enable_in_reasoning
(
...
...
tests/v1/test_serial_utils.py
View file @
8d75f22e
...
@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
...
@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
def
test_multimodal_kwargs
():
def
test_multimodal_kwargs
():
e1
=
MultiModalFieldElem
(
e1
=
MultiModalFieldElem
(
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
()
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
(),
)
)
e2
=
MultiModalFieldElem
(
e2
=
MultiModalFieldElem
(
"video"
,
"video"
,
"v0"
,
"v0"
,
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
MultiModalFlatField
([[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
[
slice
(
None
,
2
)]],
0
),
MultiModalFlatField
(
slices
=
[[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
[
slice
(
None
,
2
)]],
dim
=
0
,
),
)
)
e3
=
MultiModalFieldElem
(
e3
=
MultiModalFieldElem
(
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
4
)
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
batch_size
=
4
),
)
)
e4
=
MultiModalFieldElem
(
e4
=
MultiModalFieldElem
(
"image"
,
"image"
,
"i1"
,
"i1"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalFlatField
([
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
2
),
MultiModalFlatField
(
slices
=
[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
dim
=
2
),
)
)
audio
=
MultiModalKwargsItem
.
from_elems
([
e1
])
audio
=
MultiModalKwargsItem
.
from_elems
([
e1
])
video
=
MultiModalKwargsItem
.
from_elems
([
e2
])
video
=
MultiModalKwargsItem
.
from_elems
([
e2
])
...
@@ -138,8 +147,8 @@ def test_multimodal_kwargs():
...
@@ -138,8 +147,8 @@ def test_multimodal_kwargs():
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 143
06
, +-20 for minor changes
# expected total encoding length, should be 143
95
, +-20 for minor changes
assert
14
2
75
<=
total_len
<=
14
3
25
assert
14
3
75
<=
total_len
<=
14
4
25
decoded
=
decoder
.
decode
(
encoded
).
mm
[
0
]
decoded
=
decoder
.
decode
(
encoded
).
mm
[
0
]
assert
isinstance
(
decoded
,
MultiModalKwargsItems
)
assert
isinstance
(
decoded
,
MultiModalKwargsItems
)
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
8d75f22e
...
@@ -6,8 +6,10 @@ import pytest
...
@@ -6,8 +6,10 @@ import pytest
import
torch
import
torch
from
vllm.attention.backends.abstract
import
MultipleOf
from
vllm.attention.backends.abstract
import
MultipleOf
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
(
from
vllm.config
import
(
AttentionConfig
,
CacheConfig
,
CacheConfig
,
ModelConfig
,
ModelConfig
,
ParallelConfig
,
ParallelConfig
,
...
@@ -761,7 +763,11 @@ def test_init_kv_cache_with_kv_sharing_valid():
...
@@ -761,7 +763,11 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert
kv_cache_config_after_init
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
assert
kv_cache_config_after_init
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
def
test_hybrid_attention_mamba_tensor_shapes
(
monkeypatch
):
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Attention backend FLASHINFER is not supported on ROCm."
,
)
def
test_hybrid_attention_mamba_tensor_shapes
():
"""
"""
The GPU model runner creates different views into the
The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers
KVCacheTensors for the attention and mamba layers
...
@@ -802,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
...
@@ -802,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
cache_dtype
=
"auto"
,
cache_dtype
=
"auto"
,
)
)
parallel_config
=
ParallelConfig
()
parallel_config
=
ParallelConfig
()
attention_config
=
AttentionConfig
(
backend
=
AttentionBackendEnum
.
FLASHINFER
)
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
model_config
=
model_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
parallel_config
=
parallel_config
,
parallel_config
=
parallel_config
,
attention_config
=
attention_config
,
)
)
layer_0
=
"model.layers.0.self_attn.attn"
layer_0
=
"model.layers.0.self_attn.attn"
...
@@ -816,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
...
@@ -816,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
layer_4
=
"model.layers.4.mixer"
layer_4
=
"model.layers.4.mixer"
layer_5
=
"model.layers.5.mixer"
layer_5
=
"model.layers.5.mixer"
with
set_current_vllm_config
(
vllm_config
),
monkeypatch
.
context
()
as
m
:
with
set_current_vllm_config
(
vllm_config
):
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASHINFER"
)
hf_config
=
vllm_config
.
model_config
.
hf_config
hf_config
=
vllm_config
.
model_config
.
hf_config
fwd_context
=
{}
fwd_context
=
{}
for
key
in
[
layer_0
,
layer_1
]:
for
key
in
[
layer_0
,
layer_1
]:
...
@@ -847,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
...
@@ -847,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)
)
# suppress var not used error
# suppress var not used error
assert
fwd_context
is
not
None
assert
fwd_context
is
not
None
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASHINFER"
)
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
kv_cache_spec
=
runner
.
get_kv_cache_spec
()
kv_cache_spec
=
runner
.
get_kv_cache_spec
()
...
@@ -861,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
...
@@ -861,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)[
0
]
)[
0
]
runner
.
initialize_kv_cache
(
kv_cache_config
)
runner
.
initialize_kv_cache
(
kv_cache_config
)
# random partition of blocks
# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
# blocks1 will be assigned to mamba layers
num_blocks
=
kv_cache_config
.
num_blocks
num_blocks
=
kv_cache_config
.
num_blocks
ind
=
np
.
arange
(
num_blocks
)
ind
=
np
.
arange
(
num_blocks
)
np
.
random
.
shuffle
(
ind
)
np
.
random
.
shuffle
(
ind
)
blocks0
,
blocks1
=
ind
[:
(
num_blocks
//
2
)],
ind
[(
num_blocks
//
2
)
:]
blocks0
,
blocks1
=
ind
[:
(
num_blocks
//
2
)],
ind
[(
num_blocks
//
2
)
:]
attn_shape
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
].
shape
attn_shape
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
].
shape
conv_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
0
].
shape
conv_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
0
].
shape
ssm_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
1
].
shape
ssm_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
1
].
shape
# assert we are using FlashInfer
# assert we are using FlashInfer
assert
attn_shape
[
0
]
%
num_blocks
==
0
assert
attn_shape
[
0
]
%
num_blocks
==
0
block_split_ratio
=
attn_shape
[
0
]
//
num_blocks
block_split_ratio
=
attn_shape
[
0
]
//
num_blocks
# use small blocks for testing to avoid memory issues
# use small blocks for testing to avoid memory issues
test_block_size
=
min
(
2
,
len
(
blocks0
),
len
(
blocks1
))
test_block_size
=
min
(
2
,
len
(
blocks0
),
len
(
blocks1
))
# use non-overlapping blocks to avoid data contamination
# use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba
# Split kernel blocks: first half for attention, second half for mamba
mid_point
=
num_blocks
//
2
mid_point
=
num_blocks
//
2
# attention uses kernel blocks from first half (mapped to logical blocks)
# attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention
=
np
.
array
([
0
,
1
])[:
test_block_size
]
kv_blocks_for_attention
=
np
.
array
([
0
,
1
])[:
test_block_size
]
# mamba uses kernel blocks from second half
# mamba uses kernel blocks from second half
kv_blocks_for_mamba
=
np
.
array
([
mid_point
,
mid_point
+
1
])[:
test_block_size
]
kv_blocks_for_mamba
=
np
.
array
([
mid_point
,
mid_point
+
1
])[:
test_block_size
]
# create small constant tensors for testing with corrected shapes
# create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2
# attention: [block_size, ...] starting from dimension 2
attn_constant_shape
=
attn_shape
[
2
:]
attn_constant_shape
=
attn_shape
[
2
:]
conv_constant_shape
=
conv_shape
[
1
:]
conv_constant_shape
=
conv_shape
[
1
:]
ssm_constant_shape
=
ssm_shape
[
1
:]
ssm_constant_shape
=
ssm_shape
[
1
:]
attn_blocks_constant
=
torch
.
full
(
attn_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
attn_constant_shape
),
device
=
DEVICE
,
fill_value
=
3.33
(
test_block_size
,
*
attn_constant_shape
),
device
=
DEVICE
,
fill_value
=
3.33
)
)
conv_blocks_constant
=
torch
.
full
(
conv_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
conv_constant_shape
),
device
=
DEVICE
,
fill_value
=
6.66
(
test_block_size
,
*
conv_constant_shape
),
device
=
DEVICE
,
fill_value
=
6.66
)
)
ssm_blocks_constant
=
torch
.
full
(
ssm_blocks_constant
=
torch
.
full
(
(
test_block_size
,
*
ssm_constant_shape
),
device
=
DEVICE
,
fill_value
=
9.99
(
test_block_size
,
*
ssm_constant_shape
),
device
=
DEVICE
,
fill_value
=
9.99
)
)
# Fill attention blocks with constants using kv block indices
# Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention
=
kv_blocks_for_attention
*
block_split_ratio
kernel_blocks_for_attention
=
kv_blocks_for_attention
*
block_split_ratio
for
layer
in
[
layer_0
,
layer_1
]:
for
layer
in
[
layer_0
,
layer_1
]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for
i
,
kernel_block
in
enumerate
(
kernel_blocks_for_attention
):
for
i
,
kernel_block
in
enumerate
(
kernel_blocks_for_attention
):
vllm_ctx
[
layer
].
kv_cache
[
0
][
kernel_block
,
:]
=
attn_blocks_constant
[
i
]
vllm_ctx
[
layer
].
kv_cache
[
0
][
kernel_block
,
:]
=
attn_blocks_constant
[
i
]
# fill mamba blocks with constants using kernel block indices
# fill mamba blocks with constants using kernel block indices
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
# mamba: kv_cache[0][component][kernel_block_idx, ...]
# mamba: kv_cache[0][component][kernel_block_idx, ...]
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
=
conv_blocks_constant
[
i
]
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
=
conv_blocks_constant
[
i
]
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
=
ssm_blocks_constant
[
i
]
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
=
ssm_blocks_constant
[
i
]
# verify attention and mamba contents are correct
# verify attention and mamba contents are correct
for
layer
in
[
layer_0
,
layer_1
]:
for
layer
in
[
layer_0
,
layer_1
]:
for
i
,
kernel_block
in
enumerate
(
kernel_blocks_for_attention
):
for
i
,
kernel_block
in
enumerate
(
kernel_blocks_for_attention
):
actual_kv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
kernel_block
,
:]
actual_kv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
kernel_block
,
:]
expected
=
attn_blocks_constant
[
i
]
expected
=
attn_blocks_constant
[
i
]
# Check K and V separately
# Check K and V separately
assert
torch
.
equal
(
actual_kv
[
0
],
expected
)
assert
torch
.
equal
(
actual_kv
[
0
],
expected
)
assert
torch
.
equal
(
actual_kv
[
1
],
expected
)
assert
torch
.
equal
(
actual_kv
[
1
],
expected
)
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
actual_conv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
actual_conv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
actual_ssm
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
actual_ssm
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
expected_conv
=
conv_blocks_constant
[
i
]
expected_conv
=
conv_blocks_constant
[
i
]
expected_ssm
=
ssm_blocks_constant
[
i
]
expected_ssm
=
ssm_blocks_constant
[
i
]
assert
torch
.
equal
(
actual_conv
,
expected_conv
)
assert
torch
.
equal
(
actual_conv
,
expected_conv
)
assert
torch
.
equal
(
actual_ssm
,
expected_ssm
)
assert
torch
.
equal
(
actual_ssm
,
expected_ssm
)
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
for
i
,
kv_block
in
enumerate
(
kv_blocks_for_mamba
):
actual_conv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
actual_conv
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
kv_block
,
:]
actual_ssm
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
actual_ssm
=
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
kv_block
,
:]
expected_conv
=
conv_blocks_constant
[
i
]
expected_conv
=
conv_blocks_constant
[
i
]
expected_ssm
=
ssm_blocks_constant
[
i
]
expected_ssm
=
ssm_blocks_constant
[
i
]
assert
torch
.
equal
(
actual_conv
,
expected_conv
)
assert
torch
.
equal
(
actual_conv
,
expected_conv
)
assert
torch
.
equal
(
actual_ssm
,
expected_ssm
)
assert
torch
.
equal
(
actual_ssm
,
expected_ssm
)
def
test_hybrid_block_table_initialization
():
def
test_hybrid_block_table_initialization
():
...
...
tests/v1/worker/test_gpu_profiler.py
View file @
8d75f22e
...
@@ -2,8 +2,8 @@
...
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
pytest
import
vllm.envs
as
envs
from
vllm.config
import
ProfilerConfig
from
vllm.profiler.
gpu_profil
er
import
WorkerProfiler
from
vllm.profiler.
wrapp
er
import
WorkerProfiler
class
ConcreteWorkerProfiler
(
WorkerProfiler
):
class
ConcreteWorkerProfiler
(
WorkerProfiler
):
...
@@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
...
@@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
A basic implementation of a worker profiler for testing purposes.
A basic implementation of a worker profiler for testing purposes.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
,
profiler_config
:
ProfilerConfig
):
self
.
start_call_count
=
0
self
.
start_call_count
=
0
self
.
stop_call_count
=
0
self
.
stop_call_count
=
0
self
.
should_fail_start
=
False
self
.
should_fail_start
=
False
super
().
__init__
()
super
().
__init__
(
profiler_config
)
def
_start
(
self
)
->
None
:
def
_start
(
self
)
->
None
:
if
self
.
should_fail_start
:
if
self
.
should_fail_start
:
...
@@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
...
@@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
self
.
stop_call_count
+=
1
self
.
stop_call_count
+=
1
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
def
reset_mocks
():
def
default_profiler_config
():
"""Fixture to reset mocks and env variables before each test."""
return
ProfilerConfig
(
envs
.
VLLM_PROFILER_DELAY_ITERS
=
0
profiler
=
"torch"
,
envs
.
VLLM_PROFILER_MAX_ITERS
=
0
torch_profiler_dir
=
"/tmp/mock"
,
delay_iterations
=
0
,
max_iterations
=
0
,
)
def
test_immediate_start_stop
():
def
test_immediate_start_stop
(
default_profiler_config
):
"""Test standard start without delay."""
"""Test standard start without delay."""
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
start
()
profiler
.
start
()
assert
profiler
.
_running
is
True
assert
profiler
.
_running
is
True
assert
profiler
.
_active
is
True
assert
profiler
.
_active
is
True
...
@@ -48,10 +50,10 @@ def test_immediate_start_stop():
...
@@ -48,10 +50,10 @@ def test_immediate_start_stop():
assert
profiler
.
stop_call_count
==
1
assert
profiler
.
stop_call_count
==
1
def
test_delayed_start
():
def
test_delayed_start
(
default_profiler_config
):
"""Test that profiler waits for N steps before actually starting."""
"""Test that profiler waits for N steps before actually starting."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
2
default_profiler_config
.
delay_iterations
=
2
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
# User requests start
# User requests start
profiler
.
start
()
profiler
.
start
()
...
@@ -71,10 +73,10 @@ def test_delayed_start():
...
@@ -71,10 +73,10 @@ def test_delayed_start():
assert
profiler
.
start_call_count
==
1
assert
profiler
.
start_call_count
==
1
def
test_max_iterations
():
def
test_max_iterations
(
default_profiler_config
):
"""Test that profiler stops automatically after max iterations."""
"""Test that profiler stops automatically after max iterations."""
envs
.
VLLM_PROFILER_MAX_ITERS
=
2
default_profiler_config
.
max_iterations
=
2
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
start
()
profiler
.
start
()
assert
profiler
.
_running
is
True
assert
profiler
.
_running
is
True
...
@@ -95,12 +97,11 @@ def test_max_iterations():
...
@@ -95,12 +97,11 @@ def test_max_iterations():
assert
profiler
.
stop_call_count
==
1
assert
profiler
.
stop_call_count
==
1
def
test_delayed_start_and_max_iters
():
def
test_delayed_start_and_max_iters
(
default_profiler_config
):
"""Test combined delayed start and max iterations."""
"""Test combined delayed start and max iterations."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
2
default_profiler_config
.
delay_iterations
=
2
envs
.
VLLM_PROFILER_MAX_ITERS
=
2
default_profiler_config
.
max_iterations
=
2
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
start
()
profiler
.
start
()
# Step 1
# Step 1
...
@@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
...
@@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
assert
profiler
.
stop_call_count
==
1
assert
profiler
.
stop_call_count
==
1
def
test_idempotency
():
def
test_idempotency
(
default_profiler_config
):
"""Test that calling start/stop multiple times doesn't break logic."""
"""Test that calling start/stop multiple times doesn't break logic."""
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
# Double Start
# Double Start
profiler
.
start
()
profiler
.
start
()
...
@@ -142,10 +143,10 @@ def test_idempotency():
...
@@ -142,10 +143,10 @@ def test_idempotency():
assert
profiler
.
stop_call_count
==
1
# Should only stop once
assert
profiler
.
stop_call_count
==
1
# Should only stop once
def
test_step_inactive
():
def
test_step_inactive
(
default_profiler_config
):
"""Test that stepping while inactive does nothing."""
"""Test that stepping while inactive does nothing."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
2
default_profiler_config
.
delay_iterations
=
2
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
# Not started yet
# Not started yet
profiler
.
step
()
profiler
.
step
()
...
@@ -155,9 +156,9 @@ def test_step_inactive():
...
@@ -155,9 +156,9 @@ def test_step_inactive():
assert
profiler
.
start_call_count
==
0
assert
profiler
.
start_call_count
==
0
def
test_start_failure
():
def
test_start_failure
(
default_profiler_config
):
"""Test behavior when the underlying _start method raises exception."""
"""Test behavior when the underlying _start method raises exception."""
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
should_fail_start
=
True
profiler
.
should_fail_start
=
True
profiler
.
start
()
profiler
.
start
()
...
@@ -168,9 +169,9 @@ def test_start_failure():
...
@@ -168,9 +169,9 @@ def test_start_failure():
assert
profiler
.
start_call_count
==
0
# Logic failed inside start
assert
profiler
.
start_call_count
==
0
# Logic failed inside start
def
test_shutdown
():
def
test_shutdown
(
default_profiler_config
):
"""Test that shutdown calls stop only if running."""
"""Test that shutdown calls stop only if running."""
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
# Case 1: Not running
# Case 1: Not running
profiler
.
shutdown
()
profiler
.
shutdown
()
...
@@ -182,10 +183,10 @@ def test_shutdown():
...
@@ -182,10 +183,10 @@ def test_shutdown():
assert
profiler
.
stop_call_count
==
1
assert
profiler
.
stop_call_count
==
1
def
test_mixed_delay_and_stop
():
def
test_mixed_delay_and_stop
(
default_profiler_config
):
"""Test manual stop during the delay period."""
"""Test manual stop during the delay period."""
envs
.
VLLM_PROFILER_DELAY_ITERS
=
5
default_profiler_config
.
delay_iterations
=
5
profiler
=
ConcreteWorkerProfiler
()
profiler
=
ConcreteWorkerProfiler
(
default_profiler_config
)
profiler
.
start
()
profiler
.
start
()
profiler
.
step
()
profiler
.
step
()
...
...
tools/ep_kernels/install_python_libraries.sh
View file @
8d75f22e
...
@@ -10,9 +10,10 @@ set -ex
...
@@ -10,9 +10,10 @@ set -ex
CUDA_HOME
=
${
CUDA_HOME
:-
/usr/local/cuda
}
CUDA_HOME
=
${
CUDA_HOME
:-
/usr/local/cuda
}
PPLX_COMMIT_HASH
=
${
PPLX_COMMIT_HASH
:-
"12cecfd"
}
PPLX_COMMIT_HASH
=
${
PPLX_COMMIT_HASH
:-
"12cecfd"
}
DEEPEP_COMMIT_HASH
=
${
DEEPEP_COMMIT_HASH
:-
"73b6ea4"
}
DEEPEP_COMMIT_HASH
=
${
DEEPEP_COMMIT_HASH
:-
"73b6ea4"
}
NVSHMEM_VER
=
3.3.
9
NVSHMEM_VER
=
3.3.
24
# Suppports both CUDA 12 and 13
WORKSPACE
=
${
WORKSPACE
:-
$(
pwd
)
/ep_kernels_workspace
}
WORKSPACE
=
${
WORKSPACE
:-
$(
pwd
)
/ep_kernels_workspace
}
MODE
=
${
MODE
:-
install
}
MODE
=
${
MODE
:-
install
}
CUDA_VERSION_MAJOR
=
$(
${
CUDA_HOME
}
/bin/nvcc
--version
| egrep
-o
"release [0-9]+"
|
cut
-d
' '
-f
2
)
# Parse arguments
# Parse arguments
while
[[
$#
-gt
0
]]
;
do
while
[[
$#
-gt
0
]]
;
do
...
@@ -75,11 +76,9 @@ ARCH=$(uname -m)
...
@@ -75,11 +76,9 @@ ARCH=$(uname -m)
case
"
${
ARCH
,,
}
"
in
case
"
${
ARCH
,,
}
"
in
x86_64|amd64
)
x86_64|amd64
)
NVSHMEM_SUBDIR
=
"linux-x86_64"
NVSHMEM_SUBDIR
=
"linux-x86_64"
NVSHMEM_FILE
=
"libnvshmem-linux-x86_64-
${
NVSHMEM_VER
}
_cuda12-archive.tar.xz"
;;
;;
aarch64|arm64
)
aarch64|arm64
)
NVSHMEM_SUBDIR
=
"linux-sbsa"
NVSHMEM_SUBDIR
=
"linux-sbsa"
NVSHMEM_FILE
=
"libnvshmem-linux-sbsa-
${
NVSHMEM_VER
}
_cuda12-archive.tar.xz"
;;
;;
*
)
*
)
echo
"Unsupported architecture:
${
ARCH
}
"
>
&2
echo
"Unsupported architecture:
${
ARCH
}
"
>
&2
...
@@ -87,6 +86,7 @@ case "${ARCH,,}" in
...
@@ -87,6 +86,7 @@ case "${ARCH,,}" in
;;
;;
esac
esac
NVSHMEM_FILE
=
"libnvshmem-
${
NVSHMEM_SUBDIR
}
-
${
NVSHMEM_VER
}
_cuda
${
CUDA_VERSION_MAJOR
}
-archive.tar.xz"
NVSHMEM_URL
=
"https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/
${
NVSHMEM_SUBDIR
}
/
${
NVSHMEM_FILE
}
"
NVSHMEM_URL
=
"https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/
${
NVSHMEM_SUBDIR
}
/
${
NVSHMEM_FILE
}
"
pushd
"
$WORKSPACE
"
pushd
"
$WORKSPACE
"
...
@@ -142,13 +142,6 @@ clone_repo() {
...
@@ -142,13 +142,6 @@ clone_repo() {
fi
fi
}
}
deepep_cuda13_patch
()
{
cuda_version_major
=
$(
${
CUDA_HOME
}
/bin/nvcc
--version
| egrep
-o
"release [0-9]+"
|
cut
-d
' '
-f
2
)
if
[
${
cuda_version_major
}
-ge
13
]
;
then
sed
-i
"s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '
${
CUDA_HOME
}
/include/cccl']|"
"setup.py"
fi
}
do_build
()
{
do_build
()
{
local
repo
=
$1
local
repo
=
$1
local
name
=
$2
local
name
=
$2
...
@@ -160,8 +153,9 @@ do_build() {
...
@@ -160,8 +153,9 @@ do_build() {
clone_repo
"
$repo
"
"
$name
"
"
$key
"
"
$commit
"
clone_repo
"
$repo
"
"
$name
"
"
$key
"
"
$commit
"
cd
"
$name
"
cd
"
$name
"
if
[
"
$name
"
==
"DeepEP"
]
;
then
# DeepEP CUDA 13 patch
deepep_cuda13_patch
if
[[
"
$name
"
==
"DeepEP"
&&
"
${
CUDA_VERSION_MAJOR
}
"
-ge
13
]]
;
then
sed
-i
"s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '
${
CUDA_HOME
}
/include/cccl']|"
"setup.py"
fi
fi
if
[
"
$MODE
"
=
"install"
]
;
then
if
[
"
$MODE
"
=
"install"
]
;
then
...
...
use_existing_torch.py
View file @
8d75f22e
...
@@ -3,9 +3,7 @@
...
@@ -3,9 +3,7 @@
import
glob
import
glob
requires_files
=
glob
.
glob
(
"requirements/*.txt"
)
for
file
in
(
*
glob
.
glob
(
"requirements/*.txt"
),
"pyproject.toml"
):
requires_files
+=
[
"pyproject.toml"
]
for
file
in
requires_files
:
print
(
f
">>> cleaning
{
file
}
"
)
print
(
f
">>> cleaning
{
file
}
"
)
with
open
(
file
)
as
f
:
with
open
(
file
)
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
...
@@ -17,5 +15,4 @@ for file in requires_files:
...
@@ -17,5 +15,4 @@ for file in requires_files:
f
.
write
(
line
)
f
.
write
(
line
)
else
:
else
:
print
(
line
.
strip
())
print
(
line
.
strip
())
print
(
f
"<<< done cleaning
{
file
}
"
)
print
(
f
"<<< done cleaning
{
file
}
\n
"
)
print
()
vllm/_aiter_ops.py
View file @
8d75f22e
...
@@ -9,6 +9,8 @@ import vllm.envs as envs
...
@@ -9,6 +9,8 @@ import vllm.envs as envs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
_FP8_DTYPE
=
current_platform
.
fp8_dtype
()
def
is_aiter_found
()
->
bool
:
def
is_aiter_found
()
->
bool
:
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
...
@@ -22,6 +24,15 @@ def is_aiter_found() -> bool:
...
@@ -22,6 +24,15 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND
=
is_aiter_found
()
IS_AITER_FOUND
=
is_aiter_found
()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if
IS_AITER_FOUND
:
from
aiter
import
dtypes
AITER_FP8_DTYPE
=
dtypes
.
fp8
def
if_aiter_supported
(
func
:
Callable
)
->
Callable
:
def
if_aiter_supported
(
func
:
Callable
)
->
Callable
:
"""Decorator that only executes the function if
"""Decorator that only executes the function if
...
@@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
...
@@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return
wrapper
return
wrapper
def
_rocm_aiter_group_fp8_quant_impl
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
"Input shape must be divisible by group size"
from
aiter
import
QuantType
,
dtypes
,
get_hip_quant
aiter_per1x128_quant
=
get_hip_quant
(
QuantType
.
per_1x128
)
return
aiter_per1x128_quant
(
x
.
contiguous
(),
quant_dtype
=
dtypes
.
fp8
)
def
_rocm_aiter_group_fp8_quant_fake
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter
import
dtypes
M
,
N
=
x
.
shape
x_fp8
=
torch
.
empty
((
M
,
N
),
dtype
=
dtypes
.
fp8
,
device
=
x
.
device
)
out_bs
=
torch
.
empty
(
(
M
,
(
N
+
group_size
-
1
)
//
group_size
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
return
x_fp8
,
out_bs
def
_rocm_aiter_fused_moe_impl
(
def
_rocm_aiter_fused_moe_impl
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
@@ -283,6 +264,28 @@ def _rocm_aiter_grouped_topk_fake(
...
@@ -283,6 +264,28 @@ def _rocm_aiter_grouped_topk_fake(
pass
pass
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8
:
bool
|
None
=
None
def
_check_aiter_mla_fp8_support
()
->
bool
:
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
global
_AITER_MLA_SUPPORTS_FP8
if
_AITER_MLA_SUPPORTS_FP8
is
None
:
try
:
import
inspect
from
aiter.mla
import
mla_decode_fwd
sig
=
inspect
.
signature
(
mla_decode_fwd
)
_AITER_MLA_SUPPORTS_FP8
=
(
"q_scale"
in
sig
.
parameters
and
"kv_scale"
in
sig
.
parameters
)
except
Exception
:
_AITER_MLA_SUPPORTS_FP8
=
False
return
_AITER_MLA_SUPPORTS_FP8
def
_rocm_aiter_mla_decode_fwd_impl
(
def
_rocm_aiter_mla_decode_fwd_impl
(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
kv_buffer
:
torch
.
Tensor
,
kv_buffer
:
torch
.
Tensor
,
...
@@ -299,6 +302,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
...
@@ -299,6 +302,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
)
->
None
:
)
->
None
:
from
aiter.mla
import
mla_decode_fwd
from
aiter.mla
import
mla_decode_fwd
kwargs
=
{
"sm_scale"
:
sm_scale
,
"logit_cap"
:
logit_cap
,
}
# Only pass q_scale and kv_scale if the aiter library supports them
if
_check_aiter_mla_fp8_support
():
kwargs
[
"q_scale"
]
=
q_scale
kwargs
[
"kv_scale"
]
=
kv_scale
mla_decode_fwd
(
mla_decode_fwd
(
q
,
q
,
kv_buffer
.
view
(
-
1
,
1
,
1
,
q
.
shape
[
-
1
]),
kv_buffer
.
view
(
-
1
,
1
,
1
,
q
.
shape
[
-
1
]),
...
@@ -308,10 +321,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
...
@@ -308,10 +321,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_indices
,
kv_indices
,
kv_last_page_lens
,
kv_last_page_lens
,
max_seqlen_qo
,
max_seqlen_qo
,
sm_scale
=
sm_scale
,
**
kwargs
,
logit_cap
=
logit_cap
,
q_scale
=
q_scale
,
kv_scale
=
kv_scale
,
)
)
...
@@ -438,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
...
@@ -438,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
return
torch
.
empty_like
(
x
),
torch
.
empty_like
(
residual
)
return
torch
.
empty_like
(
x
),
torch
.
empty_like
(
residual
)
def
_rocm_aiter_per_tensor_quant_impl
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.quant
import
per_tensor_quant_hip
return
per_tensor_quant_hip
(
x
,
scale
,
quant_dtype
)
def
_rocm_aiter_per_tensor_quant_fake
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
empty_like
(
x
,
dtype
=
quant_dtype
),
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
x
.
device
)
def
_rocm_aiter_per_token_quant_impl
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.quant
import
dynamic_per_token_scaled_quant
assert
quant_dtype
in
[
torch
.
int8
,
_FP8_DTYPE
]
out_shape
=
x
.
shape
out
=
torch
.
empty
(
x
.
shape
,
dtype
=
_FP8_DTYPE
,
device
=
x
.
device
)
if
scale
is
None
:
scale
=
torch
.
empty
((
*
out_shape
[:
-
1
],
1
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
dynamic_per_token_scaled_quant
(
out
,
x
,
scale
,
scale_ub
=
None
,
shuffle_scale
=
False
,
num_rows
=
None
,
num_rows_factor
=
1
,
)
return
out
,
scale
def
_rocm_aiter_per_token_quant_fake
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
out_shape
=
x
.
shape
return
(
torch
.
empty
(
x
.
shape
,
dtype
=
_FP8_DTYPE
,
device
=
x
.
device
),
torch
.
empty
((
*
out_shape
[:
-
1
],
1
),
dtype
=
torch
.
float32
,
device
=
x
.
device
),
)
def
_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.triton.fused_fp8_quant
import
fused_rms_fp8_group_quant
(
x_quant
,
x_quant_scales
),
_
,
_
,
res
=
fused_rms_fp8_group_quant
(
x
,
weight
,
variance_epsilon
,
None
,
None
,
None
,
group_size
=
group_size
,
dtype_quant
=
AITER_FP8_DTYPE
,
res1
=
residual
,
)
return
(
x_quant
,
x_quant_scales
,
res
)
def
_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake
(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
scale_shape
=
(
M
,
(
N
+
group_size
-
1
)
//
group_size
)
return
(
torch
.
empty_like
(
x
,
dtype
=
AITER_FP8_DTYPE
,
device
=
x
.
device
),
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
float32
,
device
=
x
.
device
),
torch
.
empty_like
(
residual
,
device
=
residual
.
device
),
)
def
_rocm_aiter_rmsnorm_fp8_group_quant_impl
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.triton.fused_fp8_quant
import
fused_rms_fp8_group_quant
(
x_quant
,
x_quant_scales
),
_
,
_
,
res
=
fused_rms_fp8_group_quant
(
x
,
weight
,
variance_epsilon
,
None
,
None
,
None
,
group_size
=
group_size
,
dtype_quant
=
AITER_FP8_DTYPE
,
res1
=
None
,
)
return
(
x_quant
,
x_quant_scales
)
def
_rocm_aiter_rmsnorm_fp8_group_quant_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
scale_shape
=
(
M
,
(
N
+
group_size
-
1
)
//
group_size
)
return
(
torch
.
empty_like
(
x
,
dtype
=
AITER_FP8_DTYPE
,
device
=
x
.
device
),
torch
.
empty
(
scale_shape
,
dtype
=
torch
.
float32
,
device
=
x
.
device
),
)
def
_rocm_aiter_group_fp8_quant_impl
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
x
.
shape
[
-
1
]
%
group_size
==
0
,
"Input shape must be divisible by group size"
from
aiter
import
QuantType
,
get_hip_quant
aiter_per1x128_quant
=
get_hip_quant
(
QuantType
.
per_1x128
)
return
aiter_per1x128_quant
(
x
.
contiguous
(),
quant_dtype
=
AITER_FP8_DTYPE
)
def
_rocm_aiter_group_fp8_quant_fake
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
x_fp8
=
torch
.
empty
((
M
,
N
),
dtype
=
AITER_FP8_DTYPE
,
device
=
x
.
device
)
out_bs
=
torch
.
empty
(
(
M
,
(
N
+
group_size
-
1
)
//
group_size
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
return
x_fp8
,
out_bs
def
_rocm_aiter_act_mul_and_fp8_group_quant_impl
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.triton.activation
import
act_mul_and_fp8_group_quant
return
act_mul_and_fp8_group_quant
(
x
,
activation
=
"silu"
,
group_size
=
group_size
,
dtype_quant
=
AITER_FP8_DTYPE
,
)
def
_rocm_aiter_act_mul_and_fp8_group_quant_fake
(
x
:
torch
.
Tensor
,
group_size
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
M
,
N
=
x
.
shape
assert
N
%
2
==
0
N_half
=
N
//
2
x_fp8
=
torch
.
empty
((
M
,
N_half
),
dtype
=
AITER_FP8_DTYPE
,
device
=
x
.
device
)
out_bs
=
torch
.
empty
(
(
M
,
(
N_half
+
group_size
-
1
)
//
group_size
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
return
x_fp8
,
out_bs
# Global flag to ensure ops are registered only once
# Global flag to ensure ops are registered only once
_OPS_REGISTERED
=
False
_OPS_REGISTERED
=
False
...
@@ -473,7 +672,7 @@ class rocm_aiter_ops:
...
@@ -473,7 +672,7 @@ class rocm_aiter_ops:
@
if_aiter_supported
@
if_aiter_supported
def
is_linear_fp8_enaled
(
cls
)
->
bool
:
def
is_linear_fp8_enaled
(
cls
)
->
bool
:
""" "Verifies device specs and availability of env variable."""
""" "Verifies device specs and availability of env variable."""
return
cls
.
is_linear_enabled
()
and
current_platform
.
is_fp8_fnuz
()
return
cls
.
is_linear_enabled
()
@
classmethod
@
classmethod
@
if_aiter_supported
@
if_aiter_supported
...
@@ -548,14 +747,6 @@ class rocm_aiter_ops:
...
@@ -548,14 +747,6 @@ class rocm_aiter_ops:
)
)
# register all the custom ops here
# register all the custom ops here
direct_register_custom_op
(
op_name
=
"rocm_aiter_group_fp8_quant"
,
op_func
=
_rocm_aiter_group_fp8_quant_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_group_fp8_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"rocm_aiter_asm_moe_tkw1"
,
op_name
=
"rocm_aiter_asm_moe_tkw1"
,
op_func
=
_rocm_aiter_asm_moe_tkw1_impl
,
op_func
=
_rocm_aiter_asm_moe_tkw1_impl
,
...
@@ -615,27 +806,62 @@ class rocm_aiter_ops:
...
@@ -615,27 +806,62 @@ class rocm_aiter_ops:
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"rocm_aiter_gemm_a8w8_blockscale"
,
op_name
=
"rocm_aiter_gemm_a8w8_blockscale"
,
op_func
=
_rocm_aiter_gemm_a8w8_blockscale_impl
,
op_func
=
_rocm_aiter_gemm_a8w8_blockscale_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_gemm_a8w8_blockscale_fake
,
fake_impl
=
_rocm_aiter_gemm_a8w8_blockscale_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"rocm_aiter_rms_norm"
,
op_name
=
"rocm_aiter_rms_norm"
,
op_func
=
_rocm_aiter_rms_norm_impl
,
op_func
=
_rocm_aiter_rms_norm_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_rms_norm_fake
,
fake_impl
=
_rocm_aiter_rms_norm_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"rocm_aiter_rmsnorm2d_fwd_with_add"
,
op_name
=
"rocm_aiter_rmsnorm2d_fwd_with_add"
,
op_func
=
_rocm_aiter_rmsnorm2d_fwd_with_add_impl
,
op_func
=
_rocm_aiter_rmsnorm2d_fwd_with_add_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_rmsnorm2d_fwd_with_add_fake
,
fake_impl
=
_rocm_aiter_rmsnorm2d_fwd_with_add_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_rmsnorm_fp8_group_quant"
,
op_func
=
_rocm_aiter_rmsnorm_fp8_group_quant_impl
,
fake_impl
=
_rocm_aiter_rmsnorm_fp8_group_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_rmsnorm_with_add_fp8_group_quant"
,
op_func
=
_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl
,
fake_impl
=
_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_act_mul_and_fp8_group_quant"
,
op_func
=
_rocm_aiter_act_mul_and_fp8_group_quant_impl
,
fake_impl
=
_rocm_aiter_act_mul_and_fp8_group_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_group_fp8_quant"
,
op_func
=
_rocm_aiter_group_fp8_quant_impl
,
fake_impl
=
_rocm_aiter_group_fp8_quant_fake
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_per_tensor_quant"
,
op_func
=
_rocm_aiter_per_tensor_quant_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_per_tensor_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_per_token_quant"
,
op_func
=
_rocm_aiter_per_token_quant_impl
,
mutates_args
=
[
"scale"
],
fake_impl
=
_rocm_aiter_per_token_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
_OPS_REGISTERED
=
True
_OPS_REGISTERED
=
True
@
staticmethod
@
staticmethod
...
@@ -830,6 +1056,22 @@ class rocm_aiter_ops:
...
@@ -830,6 +1056,22 @@ class rocm_aiter_ops:
kv_scale
=
kv_scale
,
kv_scale
=
kv_scale
,
)
)
@
staticmethod
def
per_tensor_quant
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
vllm
.
rocm_aiter_per_tensor_quant
(
x
,
quant_dtype
,
scale
)
@
staticmethod
def
per_token_quant
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
vllm
.
rocm_aiter_per_token_quant
(
x
,
quant_dtype
,
scale
)
@
staticmethod
@
staticmethod
def
triton_fp4_gemm_dynamic_qaunt
(
def
triton_fp4_gemm_dynamic_qaunt
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
vllm/_custom_ops.py
View file @
8d75f22e
...
@@ -441,6 +441,46 @@ def rms_norm_dynamic_per_token_quant(
...
@@ -441,6 +441,46 @@ def rms_norm_dynamic_per_token_quant(
return
output
,
scales
return
output
,
scales
# fused quant layer norm ops blocked
def
rms_norm_per_block_quant
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
,
quant_dtype
:
torch
.
dtype
,
group_size
:
list
[
int
],
scale_ub
:
torch
.
Tensor
|
None
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
is_scale_transposed
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
len
(
group_size
)
==
2
output
=
torch
.
empty_like
(
input
,
dtype
=
quant_dtype
)
if
is_scale_transposed
:
scales
=
torch
.
empty
(
(
input
.
shape
[
-
1
]
//
group_size
[
1
],
input
.
numel
()
//
input
.
shape
[
-
1
]),
device
=
input
.
device
,
dtype
=
torch
.
float32
,
).
transpose
(
0
,
1
)
else
:
scales
=
torch
.
empty
(
(
input
.
numel
()
//
input
.
shape
[
-
1
],
input
.
shape
[
-
1
]
//
group_size
[
1
]),
device
=
input
.
device
,
dtype
=
torch
.
float32
,
)
torch
.
ops
.
_C
.
rms_norm_per_block_quant
(
output
,
input
,
weight
,
scales
,
epsilon
,
scale_ub
,
residual
,
group_size
[
1
],
is_scale_transposed
,
)
return
output
,
scales
# quantization ops
# quantization ops
# awq
# awq
def
awq_dequantize
(
def
awq_dequantize
(
...
@@ -660,6 +700,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
...
@@ -660,6 +700,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def
cutlass_encode_and_reorder_int4b_fake
(
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
cutlass_encode_and_reorder_int4b_fake
(
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b
,
memory_format
=
torch
.
contiguous_format
)
return
torch
.
empty_like
(
b
,
memory_format
=
torch
.
contiguous_format
)
@
register_fake
(
"_C::cutlass_encode_and_reorder_int4b_grouped"
)
def
cutlass_encode_and_reorder_int4b_grouped_fake
(
b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b
,
memory_format
=
torch
.
contiguous_format
)
if
hasattr
(
torch
.
ops
.
_C
,
"allspark_w8a16_gemm"
):
if
hasattr
(
torch
.
ops
.
_C
,
"allspark_w8a16_gemm"
):
...
@@ -1023,6 +1067,7 @@ def get_cutlass_moe_mm_problem_sizes(
...
@@ -1023,6 +1067,7 @@ def get_cutlass_moe_mm_problem_sizes(
n
:
int
,
n
:
int
,
k
:
int
,
k
:
int
,
blockscale_offsets
:
torch
.
Tensor
|
None
=
None
,
blockscale_offsets
:
torch
.
Tensor
|
None
=
None
,
force_swap_ab
:
bool
|
None
=
None
,
):
):
"""
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
Compute only the per-expert problem sizes needed by the two grouped matrix
...
@@ -1032,9 +1077,20 @@ def get_cutlass_moe_mm_problem_sizes(
...
@@ -1032,9 +1077,20 @@ def get_cutlass_moe_mm_problem_sizes(
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
multiplication for the two grouped MMs
used in the fused MoE operation.
used in the fused MoE operation.
Optional:
- force_swap_ab: If set to True or False, explicitly enable or disable the
A/B input swap optimization. If None (default), the swap
is selected automatically based on tensor sizes.
"""
"""
return
torch
.
ops
.
_C
.
get_cutlass_moe_mm_problem_sizes
(
return
torch
.
ops
.
_C
.
get_cutlass_moe_mm_problem_sizes
(
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
topk_ids
,
problem_sizes1
,
problem_sizes2
,
num_experts
,
n
,
k
,
blockscale_offsets
,
force_swap_ab
,
)
)
...
@@ -1422,6 +1478,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
...
@@ -1422,6 +1478,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
return
torch
.
ops
.
_C
.
cutlass_encode_and_reorder_int4b
(
b
)
return
torch
.
ops
.
_C
.
cutlass_encode_and_reorder_int4b
(
b
)
def
cutlass_w4a8_moe_mm
(
out_tensors
:
torch
.
Tensor
,
a_tensors
:
torch
.
Tensor
,
b_tensors
:
torch
.
Tensor
,
a_scales
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_group_scales
:
torch
.
Tensor
,
b_group_size
:
int
,
expert_offsets
:
torch
.
Tensor
,
problem_sizes
:
torch
.
Tensor
,
a_strides
:
torch
.
Tensor
,
b_strides
:
torch
.
Tensor
,
c_strides
:
torch
.
Tensor
,
group_scale_strides
:
torch
.
Tensor
,
maybe_schedule
:
str
|
None
=
None
,
):
"""
Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
and both per-channel + per-token scaling in the epilogue.
Args:
out_tensors:
Output buffer for all experts (updated in-place).
a_tensors:
FP8 (E4M3FN) activations for all experts.
b_tensors:
INT4-packed weight matrix for all experts, packed to INT32
a_scales:
Per-token FP8 activation scales, applied in the epilogue.
b_scales:
Per-channel FP8 weight scales for each expert, applied in the epilogue.
b_group_scales:
FP8 scale values for group-wise INT4 weight blocks.
b_group_size:
Number of elements grouped under each entry of b_group_scales.
expert_offsets:
Cumulative token offsets
problem_sizes:
Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
a/b/c/group_scale_strides:
Strides describing the memory layout of the input tensors.
maybe_schedule:
Optional override to choose a specific kernel or epilogue schedule.
Returns:
out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result.
"""
return
torch
.
ops
.
_C
.
cutlass_w4a8_moe_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
b_group_scales
,
b_group_size
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
,
group_scale_strides
,
maybe_schedule
,
)
def
cutlass_encode_and_reorder_int4b_grouped
(
b_tensors
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
_C
.
cutlass_encode_and_reorder_int4b_grouped
(
b_tensors
)
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
@
register_fake
(
"_C::permute_cols"
)
@
register_fake
(
"_C::permute_cols"
)
...
@@ -1603,7 +1731,7 @@ def scaled_fp8_quant(
...
@@ -1603,7 +1731,7 @@ def scaled_fp8_quant(
output
,
input
,
scale
,
scale_ub
output
,
input
,
scale
,
scale_ub
)
)
else
:
else
:
scale
=
torch
.
empty
(
(
1
,
1
)
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
scale
=
torch
.
empty
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
else
:
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
...
@@ -1882,6 +2010,7 @@ def moe_align_block_size(
...
@@ -1882,6 +2010,7 @@ def moe_align_block_size(
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
experts_ids
:
torch
.
Tensor
,
experts_ids
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_moe_C
.
moe_align_block_size
(
torch
.
ops
.
_moe_C
.
moe_align_block_size
(
topk_ids
,
topk_ids
,
...
@@ -1890,6 +2019,7 @@ def moe_align_block_size(
...
@@ -1890,6 +2019,7 @@ def moe_align_block_size(
sorted_token_ids
,
sorted_token_ids
,
experts_ids
,
experts_ids
,
num_tokens_post_pad
,
num_tokens_post_pad
,
expert_map
,
)
)
...
@@ -1924,6 +2054,7 @@ def moe_lora_align_block_size(
...
@@ -1924,6 +2054,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
adapter_enabled
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
lora_ids
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_moe_C
.
moe_lora_align_block_size
(
torch
.
ops
.
_moe_C
.
moe_lora_align_block_size
(
topk_ids
,
topk_ids
,
...
@@ -1938,6 +2069,7 @@ def moe_lora_align_block_size(
...
@@ -1938,6 +2069,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad
,
num_tokens_post_pad
,
adapter_enabled
,
adapter_enabled
,
lora_ids
,
lora_ids
,
expert_map
,
)
)
...
...
vllm/attention/backends/abstract.py
View file @
8d75f22e
...
@@ -166,6 +166,10 @@ class AttentionBackend(ABC):
...
@@ -166,6 +166,10 @@ class AttentionBackend(ABC):
def
supports_sink
(
cls
)
->
bool
:
def
supports_sink
(
cls
)
->
bool
:
return
False
return
False
@
classmethod
def
supports_mm_prefix
(
cls
)
->
bool
:
return
False
@
classmethod
@
classmethod
def
is_sparse
(
cls
)
->
bool
:
def
is_sparse
(
cls
)
->
bool
:
return
False
return
False
...
@@ -207,6 +211,7 @@ class AttentionBackend(ABC):
...
@@ -207,6 +211,7 @@ class AttentionBackend(ABC):
use_mla
:
bool
,
use_mla
:
bool
,
has_sink
:
bool
,
has_sink
:
bool
,
use_sparse
:
bool
,
use_sparse
:
bool
,
use_mm_prefix
:
bool
,
device_capability
:
"DeviceCapability"
,
device_capability
:
"DeviceCapability"
,
attn_type
:
str
,
attn_type
:
str
,
)
->
list
[
str
]:
)
->
list
[
str
]:
...
@@ -219,6 +224,10 @@ class AttentionBackend(ABC):
...
@@ -219,6 +224,10 @@ class AttentionBackend(ABC):
invalid_reasons
.
append
(
"kv_cache_dtype not supported"
)
invalid_reasons
.
append
(
"kv_cache_dtype not supported"
)
if
not
cls
.
supports_block_size
(
block_size
):
if
not
cls
.
supports_block_size
(
block_size
):
invalid_reasons
.
append
(
"block_size not supported"
)
invalid_reasons
.
append
(
"block_size not supported"
)
if
use_mm_prefix
and
not
cls
.
supports_mm_prefix
():
invalid_reasons
.
append
(
"partial multimodal token full attention not supported"
)
if
use_mla
!=
cls
.
is_mla
():
if
use_mla
!=
cls
.
is_mla
():
if
use_mla
:
if
use_mla
:
invalid_reasons
.
append
(
"MLA not supported"
)
invalid_reasons
.
append
(
"MLA not supported"
)
...
@@ -289,6 +298,16 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -289,6 +298,16 @@ class AttentionImpl(ABC, Generic[T]):
# even if they can return lse (for efficiency reasons)
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode
:
bool
=
False
need_to_return_lse_for_decode
:
bool
=
False
# Whether this attention implementation supports pre-quantized query input.
# When True, the attention layer will quantize queries before passing them
# to this backend, allowing torch.compile to fuse the quantization with
# previous operations. This is typically supported when using FP8 KV cache
# with compatible attention kernels (e.g., TRT-LLM).
# Subclasses should set this in __init__.
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input
:
bool
=
False
dcp_world_size
:
int
dcp_world_size
:
int
dcp_rank
:
int
dcp_rank
:
int
...
@@ -368,22 +387,6 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -368,22 +387,6 @@ class AttentionImpl(ABC, Generic[T]):
"""
"""
return
False
return
False
def
supports_quant_query_input
(
self
)
->
bool
:
"""
Check if this attention implementation supports pre-quantized query input.
When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO add support to more backends:
https://github.com/vllm-project/vllm/issues/25584
Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return
False
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
pass
pass
...
...
vllm/attention/layer.py
View file @
8d75f22e
...
@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
...
@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.batch_invariant
import
vllm_is_batch_invariant
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ColumnParallelLinear
,
UnquantizedLinearMethod
,
UnquantizedLinearMethod
,
...
@@ -88,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
...
@@ -88,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
if
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
:
if
attn_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
from
aiter
import
flash_attn_varlen_func
else
:
else
:
from
vllm.attention.utils.fa_utils
import
flash_attn_varlen_func
try
:
from
vllm.attention.utils.fa_utils
import
flash_attn_varlen_func
except
ImportError
:
flash_attn_varlen_func
=
None
else
:
else
:
flash_attn_varlen_func
=
None
flash_attn_varlen_func
=
None
...
@@ -230,6 +234,10 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -230,6 +234,10 @@ class Attention(nn.Module, AttentionLayerBase):
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
self
.
has_sink
=
extra_impl_args
.
get
(
"sinks"
)
is
not
None
self
.
has_sink
=
extra_impl_args
.
get
(
"sinks"
)
is
not
None
# NOTE: model_config may be None during certain tests
model_config
=
vllm_config
.
model_config
self
.
use_mm_prefix
=
model_config
is
not
None
and
model_config
.
is_mm_prefix_lm
# During model initialization, the default dtype is set as the model
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
# weight and activation dtype.
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
...
@@ -241,11 +249,30 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -241,11 +249,30 @@ class Attention(nn.Module, AttentionLayerBase):
block_size
,
block_size
,
use_mla
=
False
,
use_mla
=
False
,
has_sink
=
self
.
has_sink
,
has_sink
=
self
.
has_sink
,
use_mm_prefix
=
self
.
use_mm_prefix
,
attn_type
=
attn_type
,
attn_type
=
attn_type
,
)
)
else
:
else
:
self
.
attn_backend
=
attn_backend
self
.
attn_backend
=
attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
and
vllm_is_batch_invariant
()
and
(
self
.
attn_backend
.
get_name
()
==
"FLASHINFER"
or
self
.
attn_backend
.
get_name
()
==
"TRITON_MLA"
)
):
logger
.
warning_once
(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported."
,
scope
=
"local"
,
)
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
self
.
attn_backend
.
get_impl_cls
()
impl_cls
=
self
.
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
self
.
impl
=
impl_cls
(
num_heads
,
num_heads
,
...
@@ -303,7 +330,7 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -303,7 +330,7 @@ class Attention(nn.Module, AttentionLayerBase):
self
.
query_quant
=
None
self
.
query_quant
=
None
if
(
if
(
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
self
.
impl
.
supports_quant_query_input
()
and
self
.
impl
.
supports_quant_query_input
):
):
self
.
query_quant
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
)
self
.
query_quant
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
)
...
@@ -338,7 +365,7 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -338,7 +365,7 @@ class Attention(nn.Module, AttentionLayerBase):
assert
self
.
kv_cache_dtype
in
{
"fp8"
,
"fp8_e4m3"
}
assert
self
.
kv_cache_dtype
in
{
"fp8"
,
"fp8_e4m3"
}
# check if query quantization is supported
# check if query quantization is supported
if
self
.
impl
.
supports_quant_query_input
()
:
if
self
.
impl
.
supports_quant_query_input
:
query
,
_
=
self
.
query_quant
(
query
,
self
.
_q_scale
)
query
,
_
=
self
.
query_quant
(
query
,
self
.
_q_scale
)
if
self
.
use_output
:
if
self
.
use_output
:
...
@@ -623,6 +650,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -623,6 +650,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla
=
True
,
use_mla
=
True
,
use_sparse
=
use_sparse
,
use_sparse
=
use_sparse
,
)
)
if
(
cache_config
is
not
None
and
cache_config
.
enable_prefix_caching
and
vllm_is_batch_invariant
()
and
(
self
.
attn_backend
.
get_name
()
==
"TRITON_MLA"
or
self
.
attn_backend
.
get_name
()
==
"FLASHINFER"
)
):
logger
.
warning_once
(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported."
,
scope
=
"local"
,
)
cache_config
.
enable_prefix_caching
=
False
impl_cls
=
cast
(
type
[
MLAAttentionImpl
],
self
.
attn_backend
.
get_impl_cls
())
impl_cls
=
cast
(
type
[
MLAAttentionImpl
],
self
.
attn_backend
.
get_impl_cls
())
self
.
impl
=
impl_cls
(
self
.
impl
=
impl_cls
(
num_heads
=
self
.
num_heads
,
num_heads
=
self
.
num_heads
,
...
...
vllm/attention/layers/cross_attention.py
View file @
8d75f22e
...
@@ -103,7 +103,7 @@ def create_cross_attention_backend(
...
@@ -103,7 +103,7 @@ def create_cross_attention_backend(
# needed here to know how many tokens to attend to from the cached
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
# cross-attention KV cache.
new_metadata
.
seq_lens
=
common_attn_metadata
.
encoder_seq_lens
new_metadata
.
seq_lens
=
common_attn_metadata
.
encoder_seq_lens
new_metadata
.
seq_lens_cpu
=
torch
.
from_numpy
(
new_metadata
.
_
seq_lens_cpu
=
torch
.
from_numpy
(
common_attn_metadata
.
encoder_seq_lens_cpu
common_attn_metadata
.
encoder_seq_lens_cpu
)
)
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
…
33
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment