"benchmarks/vscode:/vscode.git/clone" did not exist on "a657bfc48a11d87de146629a7b6c03e9ccfbc3fc"
Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random import random
import sys
from typing import Any from typing import Any
import pytest import pytest
...@@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test ...@@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
from tests.v1.logits_processors.utils import ( from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN, DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS, MAX_TOKENS,
MODEL_NAME, MODEL_NAME,
POOLING_MODEL_NAME, POOLING_MODEL_NAME,
...@@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import ( ...@@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
CustomLogitprocSource, CustomLogitprocSource,
DummyLogitsProcessor, DummyLogitsProcessor,
WrappedPerReqLogitsProcessor, WrappedPerReqLogitsProcessor,
dummy_module,
prompts, prompts,
) )
from tests.v1.logits_processors.utils import entry_points as fake_entry_points from tests.v1.logits_processors.utils import entry_points as fake_entry_points
...@@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource ...@@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
kwargs: dict[str, list[str | type[LogitsProcessor]]] = {} kwargs: dict[str, list[str | type[LogitsProcessor]]] = {}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
# Scenario: load logitproc based on fully-qualified class name (FQCN) # Scenario: load logitproc based on fully-qualified class name (FQCN)
# Inject dummy module which defines logitproc
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
# Scenario: load logitproc from provided class object # Scenario: load logitproc from provided class object
......
...@@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te ...@@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
from tests.v1.logits_processors.utils import ( from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN, DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS, MAX_TOKENS,
MODEL_NAME, MODEL_NAME,
TEMP_GREEDY, TEMP_GREEDY,
dummy_module,
prompts, prompts,
) )
from tests.v1.logits_processors.utils import entry_points as fake_entry_points from tests.v1.logits_processors.utils import entry_points as fake_entry_points
...@@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint( ...@@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
main.main() main.main()
def _server_with_logitproc_module( def _server_with_logitproc_fqcn(
env_dict: dict[str, str] | None, env_dict: dict[str, str] | None,
model: str, model: str,
vllm_serve_args: list[str], vllm_serve_args: list[str],
) -> None: ) -> None:
"""Start vLLM server, inject module with dummy logitproc""" """Start vLLM server, inject module with dummy logitproc"""
# Patch `modules` to inject dummy logitproc module
from vllm.entrypoints.cli import main from vllm.entrypoints.cli import main
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
# fork is required for workers to see entrypoint patch
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
if env_dict is not None: if env_dict is not None:
os.environ.update(env_dict) os.environ.update(env_dict)
...@@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch): ...@@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
if request.param: if request.param:
# Launch server, append FQCN argument, inject dummy logitproc module # Launch server, append FQCN argument, inject dummy logitproc module
args = default_server_args + request.param args = default_server_args + request.param
_server_fxn = _server_with_logitproc_module _server_fxn = _server_with_logitproc_fqcn
else: else:
# Launch server, inject dummy logitproc entrypoint # Launch server, inject dummy logitproc entrypoint
args = default_server_args args = default_server_args
......
...@@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token" ...@@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY = 0.0 TEMP_GREEDY = 0.0
MAX_TOKENS = 20 MAX_TOKENS = 20
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc" DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "DummyModule" DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils"
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor" DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.metrics.stats import IterationStats from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
def test_iteration_stats_repr(): def test_iteration_stats_repr():
iteration_stats = IterationStats() iteration_stats = IterationStats()
assert repr(iteration_stats).startswith("IterationStats(") assert repr(iteration_stats).startswith("IterationStats(")
def test_prefill_kv_computed_with_cache():
"""Test that prefill KV compute correctly excludes cached tokens."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 5.0
req_stats.num_generation_tokens = 50
# Case 1: With prefix cache (1200 tokens cached)
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=10000,
max_tokens_param=100,
req_stats=req_stats,
num_cached_tokens=1200,
)
finished_req = iteration_stats.finished_requests[0]
assert finished_req.num_prompt_tokens == 10000
assert finished_req.num_cached_tokens == 1200
# Verify calculation: prefill KV = prompt tokens - cached tokens
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 8800 # 10000 - 1200
def test_prefill_kv_computed_no_cache():
"""Test prefill KV compute without prefix caching."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 2.0
req_stats.num_generation_tokens = 10
# Case 2: No prefix cache
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=2000,
max_tokens_param=100,
req_stats=req_stats,
num_cached_tokens=0,
)
finished_req = iteration_stats.finished_requests[0]
assert finished_req.num_prompt_tokens == 2000
assert finished_req.num_cached_tokens == 0
# Verify calculation: prefill KV = full prompt when no cache
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 2000
def test_prefill_kv_computed_edge_cases():
"""Test edge cases for prefill KV compute calculation."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 1.0
req_stats.num_generation_tokens = 1
# Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully)
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=100,
max_tokens_param=10,
req_stats=req_stats,
num_cached_tokens=-1,
)
finished_req = iteration_stats.finished_requests[0]
# max() should handle negative values
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 100 # Should treat negative as 0
# Case 4: All tokens cached (shouldn't happen in practice)
iteration_stats2 = IterationStats()
iteration_stats2.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=100,
max_tokens_param=10,
req_stats=req_stats,
num_cached_tokens=100,
)
finished_req2 = iteration_stats2.finished_requests[0]
prefill_kv_computed2 = finished_req2.num_prompt_tokens - max(
finished_req2.num_cached_tokens, 0
)
assert prefill_kv_computed2 == 0 # All cached, nothing computed
...@@ -339,7 +339,7 @@ def test_load_model( ...@@ -339,7 +339,7 @@ def test_load_model(
"multi-token eagle spec decode on current platform" "multi-token eagle spec decode on current platform"
) )
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# Setup draft model mock # Setup draft model mock
...@@ -436,7 +436,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -436,7 +436,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
"because it requires special input mocking." "because it requires special input mocking."
) )
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# Use GPU device # Use GPU device
...@@ -543,6 +543,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -543,6 +543,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
attn_metadata_builder_cls, _ = try_get_attention_backend( attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN AttentionBackendEnum.TREE_ATTN
) )
elif attn_backend == "ROCM_AITER_FA":
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.ROCM_AITER_FA
)
else: else:
raise ValueError(f"Unsupported attention backend: {attn_backend}") raise ValueError(f"Unsupported attention backend: {attn_backend}")
......
...@@ -47,7 +47,7 @@ def test_eagle_max_len( ...@@ -47,7 +47,7 @@ def test_eagle_max_len(
"multi-token eagle spec decode on current platform" "multi-token eagle spec decode on current platform"
) )
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1") m.setenv("VLLM_ROCM_USE_AITER", "1")
llm = LLM( llm = LLM(
...@@ -82,7 +82,7 @@ def test_eagle_max_len( ...@@ -82,7 +82,7 @@ def test_eagle_max_len(
len(o.prompt_token_ids) len(o.prompt_token_ids)
< 80 < 80
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids) < len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
< 200 <= 200
), ( ), (
"This test is only meaningful if the output " "This test is only meaningful if the output "
"is longer than the eagle max length" "is longer than the eagle max length"
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from vllm.config import SpeculativeConfig from vllm.config import SpeculativeConfig
from vllm.model_executor.models.interfaces import supports_eagle3 from vllm.model_executor.models.interfaces import supports_eagle3
from vllm.platforms import current_platform
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3 ...@@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
pytest.param( pytest.param(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
id="qwen3-eagle3-speculator-w4a16-verifier", id="qwen3-eagle3-speculator-w4a16-verifier",
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="The tests are skipped on rocm platform.",
),
), ),
], ],
) )
......
...@@ -88,8 +88,8 @@ def forward_attention( ...@@ -88,8 +88,8 @@ def forward_attention(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(), query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(), _seq_lens_cpu=seq_lens.cpu(),
num_computed_tokens_cpu=context_lens.cpu(), _num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size, num_reqs=batch_size,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from concurrent.futures import Future
import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.config import StructuredOutputsConfig, VllmConfig from vllm.config import StructuredOutputsConfig, VllmConfig
from vllm.config.model import ModelConfig from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.speculative import SpeculativeConfig from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -116,3 +121,72 @@ def test_grammar_bitmask_with_specdec(): ...@@ -116,3 +121,72 @@ def test_grammar_bitmask_with_specdec():
) # EOS not the final token ) # EOS not the final token
grammar_bitmask(request, prompt[i:]) # EOS not present grammar_bitmask(request, prompt[i:]) # EOS not present
grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id]) grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
@pytest.mark.parametrize("async_grammar", [True, False])
def test_grammar_init_async_and_sync(async_grammar):
"""Test grammar initialization works correctly in both async and sync modes.
This test validates that the distributed_executor_backend config option
correctly controls whether grammar compilation happens asynchronously
(via executor.submit) or synchronously. When set to "external_launcher",
grammar compilation is synchronous to avoid deadlocks.
"""
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
prompt = tokenizer.encode('{"a": "b"}')
# Use "external_launcher" for sync mode, None for async mode
executor_backend = None if async_grammar else "external_launcher"
vllm_config = VllmConfig(
model_config=ModelConfig(tokenizer=TOKENIZER),
structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
parallel_config=ParallelConfig(distributed_executor_backend=executor_backend),
)
structured_output_manager = StructuredOutputManager(vllm_config)
sampling_params = SamplingParams(
structured_outputs=StructuredOutputsParams(
json='{"type": "object"}',
),
)
sampling_params.structured_outputs._backend = "guidance"
request = Request(
"test_request",
prompt_token_ids=prompt,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=tokenizer.eos_token_id,
)
structured_output_manager.grammar_init(request)
# Check the internal _grammar type immediately after init
# Before _check_grammar_completion is called, async mode should have a Future
raw_grammar = request.structured_output_request._grammar
if async_grammar:
assert isinstance(raw_grammar, Future), (
"Async mode should store a Future before completion"
)
else:
assert not isinstance(raw_grammar, Future), (
"Sync mode should store the grammar directly, not a Future"
)
# Wait for grammar to be ready (handles both async and sync cases)
start_time = time.time()
while not request.structured_output_request._check_grammar_completion():
if time.time() - start_time > 5: # 5-second timeout
pytest.fail("Grammar compilation timed out")
time.sleep(0.01)
# After completion, _grammar should no longer be a Future
assert not isinstance(request.structured_output_request._grammar, Future)
# Verify grammar is properly initialized and functional
grammar = request.structured_output_request.grammar
assert grammar is not None
assert not grammar.is_terminated()
# Verify the grammar can accept valid tokens
assert grammar.accept_tokens(request.request_id, prompt)
...@@ -70,6 +70,7 @@ class TestReasoningStructuredOutput: ...@@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
request.use_structured_output = True request.use_structured_output = True
request.prompt_token_ids = [1, 2, 3, 4, 5] request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8] request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
request.num_computed_tokens = 5
return request return request
def test_should_fill_bitmask_with_enable_in_reasoning( def test_should_fill_bitmask_with_enable_in_reasoning(
......
...@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct): ...@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
def test_multimodal_kwargs(): def test_multimodal_kwargs():
e1 = MultiModalFieldElem( e1 = MultiModalFieldElem(
"audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField() "audio",
"a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField(),
) )
e2 = MultiModalFieldElem( e2 = MultiModalFieldElem(
"video", "video",
"v0", "v0",
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)], [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), MultiModalFlatField(
slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
dim=0,
),
) )
e3 = MultiModalFieldElem( e3 = MultiModalFieldElem(
"image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4) "image",
"i0",
torch.zeros(1000, dtype=torch.int32),
MultiModalSharedField(batch_size=4),
) )
e4 = MultiModalFieldElem( e4 = MultiModalFieldElem(
"image", "image",
"i1", "i1",
torch.zeros(1000, dtype=torch.int32), torch.zeros(1000, dtype=torch.int32),
MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2), MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
) )
audio = MultiModalKwargsItem.from_elems([e1]) audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2]) video = MultiModalKwargsItem.from_elems([e2])
...@@ -138,8 +147,8 @@ def test_multimodal_kwargs(): ...@@ -138,8 +147,8 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14306, +-20 for minor changes # expected total encoding length, should be 14395, +-20 for minor changes
assert 14275 <= total_len <= 14325 assert 14375 <= total_len <= 14425
decoded = decoder.decode(encoded).mm[0] decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems) assert isinstance(decoded, MultiModalKwargsItems)
......
...@@ -6,8 +6,10 @@ import pytest ...@@ -6,8 +6,10 @@ import pytest
import torch import torch
from vllm.attention.backends.abstract import MultipleOf from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import ( from vllm.config import (
AttentionConfig,
CacheConfig, CacheConfig,
ModelConfig, ModelConfig,
ParallelConfig, ParallelConfig,
...@@ -761,7 +763,11 @@ def test_init_kv_cache_with_kv_sharing_valid(): ...@@ -761,7 +763,11 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1 assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): @pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_hybrid_attention_mamba_tensor_shapes():
""" """
The GPU model runner creates different views into the The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers KVCacheTensors for the attention and mamba layers
...@@ -802,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ...@@ -802,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
cache_dtype="auto", cache_dtype="auto",
) )
parallel_config = ParallelConfig() parallel_config = ParallelConfig()
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
scheduler_config=scheduler_config, scheduler_config=scheduler_config,
parallel_config=parallel_config, parallel_config=parallel_config,
attention_config=attention_config,
) )
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
...@@ -816,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ...@@ -816,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
layer_4 = "model.layers.4.mixer" layer_4 = "model.layers.4.mixer"
layer_5 = "model.layers.5.mixer" layer_5 = "model.layers.5.mixer"
with set_current_vllm_config(vllm_config), monkeypatch.context() as m: with set_current_vllm_config(vllm_config):
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
hf_config = vllm_config.model_config.hf_config hf_config = vllm_config.model_config.hf_config
fwd_context = {} fwd_context = {}
for key in [layer_0, layer_1]: for key in [layer_0, layer_1]:
...@@ -847,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ...@@ -847,10 +854,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
) )
# suppress var not used error # suppress var not used error
assert fwd_context is not None assert fwd_context is not None
vllm_ctx = vllm_config.compilation_config.static_forward_context vllm_ctx = vllm_config.compilation_config.static_forward_context
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
runner = GPUModelRunner(vllm_config, DEVICE) runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec() kv_cache_spec = runner.get_kv_cache_spec()
...@@ -861,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ...@@ -861,94 +865,94 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
)[0] )[0]
runner.initialize_kv_cache(kv_cache_config) runner.initialize_kv_cache(kv_cache_config)
# random partition of blocks # random partition of blocks
# blocks0 will be assigned to attention layers # blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers # blocks1 will be assigned to mamba layers
num_blocks = kv_cache_config.num_blocks num_blocks = kv_cache_config.num_blocks
ind = np.arange(num_blocks) ind = np.arange(num_blocks)
np.random.shuffle(ind) np.random.shuffle(ind)
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :] blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
# assert we are using FlashInfer # assert we are using FlashInfer
assert attn_shape[0] % num_blocks == 0 assert attn_shape[0] % num_blocks == 0
block_split_ratio = attn_shape[0] // num_blocks block_split_ratio = attn_shape[0] // num_blocks
# use small blocks for testing to avoid memory issues # use small blocks for testing to avoid memory issues
test_block_size = min(2, len(blocks0), len(blocks1)) test_block_size = min(2, len(blocks0), len(blocks1))
# use non-overlapping blocks to avoid data contamination # use non-overlapping blocks to avoid data contamination
# Split kernel blocks: first half for attention, second half for mamba # Split kernel blocks: first half for attention, second half for mamba
mid_point = num_blocks // 2 mid_point = num_blocks // 2
# attention uses kernel blocks from first half (mapped to logical blocks) # attention uses kernel blocks from first half (mapped to logical blocks)
kv_blocks_for_attention = np.array([0, 1])[:test_block_size] kv_blocks_for_attention = np.array([0, 1])[:test_block_size]
# mamba uses kernel blocks from second half # mamba uses kernel blocks from second half
kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size] kv_blocks_for_mamba = np.array([mid_point, mid_point + 1])[:test_block_size]
# create small constant tensors for testing with corrected shapes # create small constant tensors for testing with corrected shapes
# attention: [block_size, ...] starting from dimension 2 # attention: [block_size, ...] starting from dimension 2
attn_constant_shape = attn_shape[2:] attn_constant_shape = attn_shape[2:]
conv_constant_shape = conv_shape[1:] conv_constant_shape = conv_shape[1:]
ssm_constant_shape = ssm_shape[1:] ssm_constant_shape = ssm_shape[1:]
attn_blocks_constant = torch.full( attn_blocks_constant = torch.full(
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33 (test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
) )
conv_blocks_constant = torch.full( conv_blocks_constant = torch.full(
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66 (test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
) )
ssm_blocks_constant = torch.full( ssm_blocks_constant = torch.full(
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99 (test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
) )
# Fill attention blocks with constants using kv block indices # Fill attention blocks with constants using kv block indices
kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio kernel_blocks_for_attention = kv_blocks_for_attention * block_split_ratio
for layer in [layer_0, layer_1]: for layer in [layer_0, layer_1]:
# attention: kv_cache[0][kernel_block_idx, kv_idx, ...] # attention: kv_cache[0][kernel_block_idx, kv_idx, ...]
for i, kernel_block in enumerate(kernel_blocks_for_attention): for i, kernel_block in enumerate(kernel_blocks_for_attention):
vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i] vllm_ctx[layer].kv_cache[0][kernel_block, :] = attn_blocks_constant[i]
# fill mamba blocks with constants using kernel block indices # fill mamba blocks with constants using kernel block indices
for layer in [layer_2, layer_3, layer_4, layer_5]: for layer in [layer_2, layer_3, layer_4, layer_5]:
# mamba: kv_cache[0][component][kernel_block_idx, ...] # mamba: kv_cache[0][component][kernel_block_idx, ...]
for i, kv_block in enumerate(kv_blocks_for_mamba): for i, kv_block in enumerate(kv_blocks_for_mamba):
vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i] vllm_ctx[layer].kv_cache[0][0][kv_block, :] = conv_blocks_constant[i]
vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i] vllm_ctx[layer].kv_cache[0][1][kv_block, :] = ssm_blocks_constant[i]
# verify attention and mamba contents are correct # verify attention and mamba contents are correct
for layer in [layer_0, layer_1]: for layer in [layer_0, layer_1]:
for i, kernel_block in enumerate(kernel_blocks_for_attention): for i, kernel_block in enumerate(kernel_blocks_for_attention):
actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :] actual_kv = vllm_ctx[layer].kv_cache[0][kernel_block, :]
expected = attn_blocks_constant[i] expected = attn_blocks_constant[i]
# Check K and V separately # Check K and V separately
assert torch.equal(actual_kv[0], expected) assert torch.equal(actual_kv[0], expected)
assert torch.equal(actual_kv[1], expected) assert torch.equal(actual_kv[1], expected)
for layer in [layer_2, layer_3, layer_4, layer_5]: for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba): for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i] expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i] expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv) assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm) assert torch.equal(actual_ssm, expected_ssm)
for layer in [layer_2, layer_3, layer_4, layer_5]: for layer in [layer_2, layer_3, layer_4, layer_5]:
for i, kv_block in enumerate(kv_blocks_for_mamba): for i, kv_block in enumerate(kv_blocks_for_mamba):
actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :] actual_conv = vllm_ctx[layer].kv_cache[0][0][kv_block, :]
actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :] actual_ssm = vllm_ctx[layer].kv_cache[0][1][kv_block, :]
expected_conv = conv_blocks_constant[i] expected_conv = conv_blocks_constant[i]
expected_ssm = ssm_blocks_constant[i] expected_ssm = ssm_blocks_constant[i]
assert torch.equal(actual_conv, expected_conv) assert torch.equal(actual_conv, expected_conv)
assert torch.equal(actual_ssm, expected_ssm) assert torch.equal(actual_ssm, expected_ssm)
def test_hybrid_block_table_initialization(): def test_hybrid_block_table_initialization():
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import vllm.envs as envs from vllm.config import ProfilerConfig
from vllm.profiler.gpu_profiler import WorkerProfiler from vllm.profiler.wrapper import WorkerProfiler
class ConcreteWorkerProfiler(WorkerProfiler): class ConcreteWorkerProfiler(WorkerProfiler):
...@@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler): ...@@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
A basic implementation of a worker profiler for testing purposes. A basic implementation of a worker profiler for testing purposes.
""" """
def __init__(self): def __init__(self, profiler_config: ProfilerConfig):
self.start_call_count = 0 self.start_call_count = 0
self.stop_call_count = 0 self.stop_call_count = 0
self.should_fail_start = False self.should_fail_start = False
super().__init__() super().__init__(profiler_config)
def _start(self) -> None: def _start(self) -> None:
if self.should_fail_start: if self.should_fail_start:
...@@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler): ...@@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
self.stop_call_count += 1 self.stop_call_count += 1
@pytest.fixture(autouse=True) @pytest.fixture
def reset_mocks(): def default_profiler_config():
"""Fixture to reset mocks and env variables before each test.""" return ProfilerConfig(
envs.VLLM_PROFILER_DELAY_ITERS = 0 profiler="torch",
envs.VLLM_PROFILER_MAX_ITERS = 0 torch_profiler_dir="/tmp/mock",
delay_iterations=0,
max_iterations=0,
)
def test_immediate_start_stop(): def test_immediate_start_stop(default_profiler_config):
"""Test standard start without delay.""" """Test standard start without delay."""
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start() profiler.start()
assert profiler._running is True assert profiler._running is True
assert profiler._active is True assert profiler._active is True
...@@ -48,10 +50,10 @@ def test_immediate_start_stop(): ...@@ -48,10 +50,10 @@ def test_immediate_start_stop():
assert profiler.stop_call_count == 1 assert profiler.stop_call_count == 1
def test_delayed_start(): def test_delayed_start(default_profiler_config):
"""Test that profiler waits for N steps before actually starting.""" """Test that profiler waits for N steps before actually starting."""
envs.VLLM_PROFILER_DELAY_ITERS = 2 default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
# User requests start # User requests start
profiler.start() profiler.start()
...@@ -71,10 +73,10 @@ def test_delayed_start(): ...@@ -71,10 +73,10 @@ def test_delayed_start():
assert profiler.start_call_count == 1 assert profiler.start_call_count == 1
def test_max_iterations(): def test_max_iterations(default_profiler_config):
"""Test that profiler stops automatically after max iterations.""" """Test that profiler stops automatically after max iterations."""
envs.VLLM_PROFILER_MAX_ITERS = 2 default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start() profiler.start()
assert profiler._running is True assert profiler._running is True
...@@ -95,12 +97,11 @@ def test_max_iterations(): ...@@ -95,12 +97,11 @@ def test_max_iterations():
assert profiler.stop_call_count == 1 assert profiler.stop_call_count == 1
def test_delayed_start_and_max_iters(): def test_delayed_start_and_max_iters(default_profiler_config):
"""Test combined delayed start and max iterations.""" """Test combined delayed start and max iterations."""
envs.VLLM_PROFILER_DELAY_ITERS = 2 default_profiler_config.delay_iterations = 2
envs.VLLM_PROFILER_MAX_ITERS = 2 default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start() profiler.start()
# Step 1 # Step 1
...@@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters(): ...@@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
assert profiler.stop_call_count == 1 assert profiler.stop_call_count == 1
def test_idempotency(): def test_idempotency(default_profiler_config):
"""Test that calling start/stop multiple times doesn't break logic.""" """Test that calling start/stop multiple times doesn't break logic."""
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
# Double Start # Double Start
profiler.start() profiler.start()
...@@ -142,10 +143,10 @@ def test_idempotency(): ...@@ -142,10 +143,10 @@ def test_idempotency():
assert profiler.stop_call_count == 1 # Should only stop once assert profiler.stop_call_count == 1 # Should only stop once
def test_step_inactive(): def test_step_inactive(default_profiler_config):
"""Test that stepping while inactive does nothing.""" """Test that stepping while inactive does nothing."""
envs.VLLM_PROFILER_DELAY_ITERS = 2 default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
# Not started yet # Not started yet
profiler.step() profiler.step()
...@@ -155,9 +156,9 @@ def test_step_inactive(): ...@@ -155,9 +156,9 @@ def test_step_inactive():
assert profiler.start_call_count == 0 assert profiler.start_call_count == 0
def test_start_failure(): def test_start_failure(default_profiler_config):
"""Test behavior when the underlying _start method raises exception.""" """Test behavior when the underlying _start method raises exception."""
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.should_fail_start = True profiler.should_fail_start = True
profiler.start() profiler.start()
...@@ -168,9 +169,9 @@ def test_start_failure(): ...@@ -168,9 +169,9 @@ def test_start_failure():
assert profiler.start_call_count == 0 # Logic failed inside start assert profiler.start_call_count == 0 # Logic failed inside start
def test_shutdown(): def test_shutdown(default_profiler_config):
"""Test that shutdown calls stop only if running.""" """Test that shutdown calls stop only if running."""
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
# Case 1: Not running # Case 1: Not running
profiler.shutdown() profiler.shutdown()
...@@ -182,10 +183,10 @@ def test_shutdown(): ...@@ -182,10 +183,10 @@ def test_shutdown():
assert profiler.stop_call_count == 1 assert profiler.stop_call_count == 1
def test_mixed_delay_and_stop(): def test_mixed_delay_and_stop(default_profiler_config):
"""Test manual stop during the delay period.""" """Test manual stop during the delay period."""
envs.VLLM_PROFILER_DELAY_ITERS = 5 default_profiler_config.delay_iterations = 5
profiler = ConcreteWorkerProfiler() profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start() profiler.start()
profiler.step() profiler.step()
......
...@@ -10,9 +10,10 @@ set -ex ...@@ -10,9 +10,10 @@ set -ex
CUDA_HOME=${CUDA_HOME:-/usr/local/cuda} CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
PPLX_COMMIT_HASH=${PPLX_COMMIT_HASH:-"12cecfd"} PPLX_COMMIT_HASH=${PPLX_COMMIT_HASH:-"12cecfd"}
DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"} DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"}
NVSHMEM_VER=3.3.9 NVSHMEM_VER=3.3.24 # Suppports both CUDA 12 and 13
WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace} WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace}
MODE=${MODE:-install} MODE=${MODE:-install}
CUDA_VERSION_MAJOR=$(${CUDA_HOME}/bin/nvcc --version | egrep -o "release [0-9]+" | cut -d ' ' -f 2)
# Parse arguments # Parse arguments
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
...@@ -75,11 +76,9 @@ ARCH=$(uname -m) ...@@ -75,11 +76,9 @@ ARCH=$(uname -m)
case "${ARCH,,}" in case "${ARCH,,}" in
x86_64|amd64) x86_64|amd64)
NVSHMEM_SUBDIR="linux-x86_64" NVSHMEM_SUBDIR="linux-x86_64"
NVSHMEM_FILE="libnvshmem-linux-x86_64-${NVSHMEM_VER}_cuda12-archive.tar.xz"
;; ;;
aarch64|arm64) aarch64|arm64)
NVSHMEM_SUBDIR="linux-sbsa" NVSHMEM_SUBDIR="linux-sbsa"
NVSHMEM_FILE="libnvshmem-linux-sbsa-${NVSHMEM_VER}_cuda12-archive.tar.xz"
;; ;;
*) *)
echo "Unsupported architecture: ${ARCH}" >&2 echo "Unsupported architecture: ${ARCH}" >&2
...@@ -87,6 +86,7 @@ case "${ARCH,,}" in ...@@ -87,6 +86,7 @@ case "${ARCH,,}" in
;; ;;
esac esac
NVSHMEM_FILE="libnvshmem-${NVSHMEM_SUBDIR}-${NVSHMEM_VER}_cuda${CUDA_VERSION_MAJOR}-archive.tar.xz"
NVSHMEM_URL="https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/${NVSHMEM_SUBDIR}/${NVSHMEM_FILE}" NVSHMEM_URL="https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/${NVSHMEM_SUBDIR}/${NVSHMEM_FILE}"
pushd "$WORKSPACE" pushd "$WORKSPACE"
...@@ -142,13 +142,6 @@ clone_repo() { ...@@ -142,13 +142,6 @@ clone_repo() {
fi fi
} }
deepep_cuda13_patch() {
cuda_version_major=$(${CUDA_HOME}/bin/nvcc --version | egrep -o "release [0-9]+" | cut -d ' ' -f 2)
if [ ${cuda_version_major} -ge 13 ]; then
sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py"
fi
}
do_build() { do_build() {
local repo=$1 local repo=$1
local name=$2 local name=$2
...@@ -160,8 +153,9 @@ do_build() { ...@@ -160,8 +153,9 @@ do_build() {
clone_repo "$repo" "$name" "$key" "$commit" clone_repo "$repo" "$name" "$key" "$commit"
cd "$name" cd "$name"
if [ "$name" == "DeepEP" ]; then # DeepEP CUDA 13 patch
deepep_cuda13_patch if [[ "$name" == "DeepEP" && "${CUDA_VERSION_MAJOR}" -ge 13 ]]; then
sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py"
fi fi
if [ "$MODE" = "install" ]; then if [ "$MODE" = "install" ]; then
......
...@@ -3,9 +3,7 @@ ...@@ -3,9 +3,7 @@
import glob import glob
requires_files = glob.glob("requirements/*.txt") for file in (*glob.glob("requirements/*.txt"), "pyproject.toml"):
requires_files += ["pyproject.toml"]
for file in requires_files:
print(f">>> cleaning {file}") print(f">>> cleaning {file}")
with open(file) as f: with open(file) as f:
lines = f.readlines() lines = f.readlines()
...@@ -17,5 +15,4 @@ for file in requires_files: ...@@ -17,5 +15,4 @@ for file in requires_files:
f.write(line) f.write(line)
else: else:
print(line.strip()) print(line.strip())
print(f"<<< done cleaning {file}") print(f"<<< done cleaning {file}\n")
print()
...@@ -9,6 +9,8 @@ import vllm.envs as envs ...@@ -9,6 +9,8 @@ import vllm.envs as envs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
_FP8_DTYPE = current_platform.fp8_dtype()
def is_aiter_found() -> bool: def is_aiter_found() -> bool:
from importlib.util import find_spec from importlib.util import find_spec
...@@ -22,6 +24,15 @@ def is_aiter_found() -> bool: ...@@ -22,6 +24,15 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks. # we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found() IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def if_aiter_supported(func: Callable) -> Callable: def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if """Decorator that only executes the function if
...@@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable: ...@@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return wrapper return wrapper
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import dtypes
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_fused_moe_impl( def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
...@@ -283,6 +264,28 @@ def _rocm_aiter_grouped_topk_fake( ...@@ -283,6 +264,28 @@ def _rocm_aiter_grouped_topk_fake(
pass pass
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None
def _check_aiter_mla_fp8_support() -> bool:
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
global _AITER_MLA_SUPPORTS_FP8
if _AITER_MLA_SUPPORTS_FP8 is None:
try:
import inspect
from aiter.mla import mla_decode_fwd
sig = inspect.signature(mla_decode_fwd)
_AITER_MLA_SUPPORTS_FP8 = (
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
)
except Exception:
_AITER_MLA_SUPPORTS_FP8 = False
return _AITER_MLA_SUPPORTS_FP8
def _rocm_aiter_mla_decode_fwd_impl( def _rocm_aiter_mla_decode_fwd_impl(
q: torch.Tensor, q: torch.Tensor,
kv_buffer: torch.Tensor, kv_buffer: torch.Tensor,
...@@ -299,6 +302,16 @@ def _rocm_aiter_mla_decode_fwd_impl( ...@@ -299,6 +302,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
) -> None: ) -> None:
from aiter.mla import mla_decode_fwd from aiter.mla import mla_decode_fwd
kwargs = {
"sm_scale": sm_scale,
"logit_cap": logit_cap,
}
# Only pass q_scale and kv_scale if the aiter library supports them
if _check_aiter_mla_fp8_support():
kwargs["q_scale"] = q_scale
kwargs["kv_scale"] = kv_scale
mla_decode_fwd( mla_decode_fwd(
q, q,
kv_buffer.view(-1, 1, 1, q.shape[-1]), kv_buffer.view(-1, 1, 1, q.shape[-1]),
...@@ -308,10 +321,7 @@ def _rocm_aiter_mla_decode_fwd_impl( ...@@ -308,10 +321,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_indices, kv_indices,
kv_last_page_lens, kv_last_page_lens,
max_seqlen_qo, max_seqlen_qo,
sm_scale=sm_scale, **kwargs,
logit_cap=logit_cap,
q_scale=q_scale,
kv_scale=kv_scale,
) )
...@@ -438,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( ...@@ -438,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
return torch.empty_like(x), torch.empty_like(residual) return torch.empty_like(x), torch.empty_like(residual)
def _rocm_aiter_per_tensor_quant_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import per_tensor_quant_hip
return per_tensor_quant_hip(x, scale, quant_dtype)
def _rocm_aiter_per_tensor_quant_fake(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x, dtype=quant_dtype), torch.empty(
1, dtype=torch.float32, device=x.device
)
def _rocm_aiter_per_token_quant_impl(
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import dynamic_per_token_scaled_quant
assert quant_dtype in [torch.int8, _FP8_DTYPE]
out_shape = x.shape
out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
if scale is None:
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
dynamic_per_token_scaled_quant(
out,
x,
scale,
scale_ub=None,
shuffle_scale=False,
num_rows=None,
num_rows_factor=1,
)
return out, scale
def _rocm_aiter_per_token_quant_fake(
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
out_shape = x.shape
return (
torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=residual,
)
return (x_quant, x_quant_scales, res)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
torch.empty_like(residual, device=residual.device),
)
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=None,
)
return (x_quant, x_quant_scales)
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
return act_mul_and_fp8_group_quant(
x,
activation="silu",
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
)
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N_half + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
# Global flag to ensure ops are registered only once # Global flag to ensure ops are registered only once
_OPS_REGISTERED = False _OPS_REGISTERED = False
...@@ -473,7 +672,7 @@ class rocm_aiter_ops: ...@@ -473,7 +672,7 @@ class rocm_aiter_ops:
@if_aiter_supported @if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool: def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable.""" """ "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() return cls.is_linear_enabled()
@classmethod @classmethod
@if_aiter_supported @if_aiter_supported
...@@ -548,14 +747,6 @@ class rocm_aiter_ops: ...@@ -548,14 +747,6 @@ class rocm_aiter_ops:
) )
# register all the custom ops here # register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_group_fp8_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1", op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl, op_func=_rocm_aiter_asm_moe_tkw1_impl,
...@@ -615,27 +806,62 @@ class rocm_aiter_ops: ...@@ -615,27 +806,62 @@ class rocm_aiter_ops:
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8_blockscale", op_name="rocm_aiter_gemm_a8w8_blockscale",
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl, op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake, fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_rms_norm", op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl, op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake, fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add", op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl, op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake, fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
fake_impl=_rocm_aiter_group_fp8_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_per_tensor_quant",
op_func=_rocm_aiter_per_tensor_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_per_tensor_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_per_token_quant",
op_func=_rocm_aiter_per_token_quant_impl,
mutates_args=["scale"],
fake_impl=_rocm_aiter_per_token_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True _OPS_REGISTERED = True
@staticmethod @staticmethod
...@@ -830,6 +1056,22 @@ class rocm_aiter_ops: ...@@ -830,6 +1056,22 @@ class rocm_aiter_ops:
kv_scale=kv_scale, kv_scale=kv_scale,
) )
@staticmethod
def per_tensor_quant(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)
@staticmethod
def per_token_quant(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
@staticmethod @staticmethod
def triton_fp4_gemm_dynamic_qaunt( def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor, x: torch.Tensor,
......
...@@ -441,6 +441,46 @@ def rms_norm_dynamic_per_token_quant( ...@@ -441,6 +441,46 @@ def rms_norm_dynamic_per_token_quant(
return output, scales return output, scales
# fused quant layer norm ops blocked
def rms_norm_per_block_quant(
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
group_size: list[int],
scale_ub: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
is_scale_transposed: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert len(group_size) == 2
output = torch.empty_like(input, dtype=quant_dtype)
if is_scale_transposed:
scales = torch.empty(
(input.shape[-1] // group_size[1], input.numel() // input.shape[-1]),
device=input.device,
dtype=torch.float32,
).transpose(0, 1)
else:
scales = torch.empty(
(input.numel() // input.shape[-1], input.shape[-1] // group_size[1]),
device=input.device,
dtype=torch.float32,
)
torch.ops._C.rms_norm_per_block_quant(
output,
input,
weight,
scales,
epsilon,
scale_ub,
residual,
group_size[1],
is_scale_transposed,
)
return output, scales
# quantization ops # quantization ops
# awq # awq
def awq_dequantize( def awq_dequantize(
...@@ -660,6 +700,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -660,6 +700,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
return torch.empty_like(b, memory_format=torch.contiguous_format) return torch.empty_like(b, memory_format=torch.contiguous_format)
@register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor:
return torch.empty_like(b, memory_format=torch.contiguous_format)
if hasattr(torch.ops._C, "allspark_w8a16_gemm"): if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
...@@ -1023,6 +1067,7 @@ def get_cutlass_moe_mm_problem_sizes( ...@@ -1023,6 +1067,7 @@ def get_cutlass_moe_mm_problem_sizes(
n: int, n: int,
k: int, k: int,
blockscale_offsets: torch.Tensor | None = None, blockscale_offsets: torch.Tensor | None = None,
force_swap_ab: bool | None = None,
): ):
""" """
Compute only the per-expert problem sizes needed by the two grouped matrix Compute only the per-expert problem sizes needed by the two grouped matrix
...@@ -1032,9 +1077,20 @@ def get_cutlass_moe_mm_problem_sizes( ...@@ -1032,9 +1077,20 @@ def get_cutlass_moe_mm_problem_sizes(
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs multiplication for the two grouped MMs
used in the fused MoE operation. used in the fused MoE operation.
Optional:
- force_swap_ab: If set to True or False, explicitly enable or disable the
A/B input swap optimization. If None (default), the swap
is selected automatically based on tensor sizes.
""" """
return torch.ops._C.get_cutlass_moe_mm_problem_sizes( return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets topk_ids,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
blockscale_offsets,
force_swap_ab,
) )
...@@ -1422,6 +1478,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: ...@@ -1422,6 +1478,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
return torch.ops._C.cutlass_encode_and_reorder_int4b(b) return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
def cutlass_w4a8_moe_mm(
out_tensors: torch.Tensor,
a_tensors: torch.Tensor,
b_tensors: torch.Tensor,
a_scales: torch.Tensor,
b_scales: torch.Tensor,
b_group_scales: torch.Tensor,
b_group_size: int,
expert_offsets: torch.Tensor,
problem_sizes: torch.Tensor,
a_strides: torch.Tensor,
b_strides: torch.Tensor,
c_strides: torch.Tensor,
group_scale_strides: torch.Tensor,
maybe_schedule: str | None = None,
):
"""
Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
and both per-channel + per-token scaling in the epilogue.
Args:
out_tensors:
Output buffer for all experts (updated in-place).
a_tensors:
FP8 (E4M3FN) activations for all experts.
b_tensors:
INT4-packed weight matrix for all experts, packed to INT32
a_scales:
Per-token FP8 activation scales, applied in the epilogue.
b_scales:
Per-channel FP8 weight scales for each expert, applied in the epilogue.
b_group_scales:
FP8 scale values for group-wise INT4 weight blocks.
b_group_size:
Number of elements grouped under each entry of b_group_scales.
expert_offsets:
Cumulative token offsets
problem_sizes:
Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
a/b/c/group_scale_strides:
Strides describing the memory layout of the input tensors.
maybe_schedule:
Optional override to choose a specific kernel or epilogue schedule.
Returns:
out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result.
"""
return torch.ops._C.cutlass_w4a8_moe_mm(
out_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
b_group_scales,
b_group_size,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
c_strides,
group_scale_strides,
maybe_schedule,
)
def cutlass_encode_and_reorder_int4b_grouped(
b_tensors: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors)
if hasattr(torch.ops._C, "permute_cols"): if hasattr(torch.ops._C, "permute_cols"):
@register_fake("_C::permute_cols") @register_fake("_C::permute_cols")
...@@ -1603,7 +1731,7 @@ def scaled_fp8_quant( ...@@ -1603,7 +1731,7 @@ def scaled_fp8_quant(
output, input, scale, scale_ub output, input, scale, scale_ub
) )
else: else:
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32) scale = torch.empty(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else: else:
assert scale.numel() == 1, f"{scale.shape}" assert scale.numel() == 1, f"{scale.shape}"
...@@ -1882,6 +2010,7 @@ def moe_align_block_size( ...@@ -1882,6 +2010,7 @@ def moe_align_block_size(
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor, experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor, num_tokens_post_pad: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> None: ) -> None:
torch.ops._moe_C.moe_align_block_size( torch.ops._moe_C.moe_align_block_size(
topk_ids, topk_ids,
...@@ -1890,6 +2019,7 @@ def moe_align_block_size( ...@@ -1890,6 +2019,7 @@ def moe_align_block_size(
sorted_token_ids, sorted_token_ids,
experts_ids, experts_ids,
num_tokens_post_pad, num_tokens_post_pad,
expert_map,
) )
...@@ -1924,6 +2054,7 @@ def moe_lora_align_block_size( ...@@ -1924,6 +2054,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad: torch.Tensor, num_tokens_post_pad: torch.Tensor,
adapter_enabled: torch.Tensor, adapter_enabled: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> None: ) -> None:
torch.ops._moe_C.moe_lora_align_block_size( torch.ops._moe_C.moe_lora_align_block_size(
topk_ids, topk_ids,
...@@ -1938,6 +2069,7 @@ def moe_lora_align_block_size( ...@@ -1938,6 +2069,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad, num_tokens_post_pad,
adapter_enabled, adapter_enabled,
lora_ids, lora_ids,
expert_map,
) )
......
...@@ -166,6 +166,10 @@ class AttentionBackend(ABC): ...@@ -166,6 +166,10 @@ class AttentionBackend(ABC):
def supports_sink(cls) -> bool: def supports_sink(cls) -> bool:
return False return False
@classmethod
def supports_mm_prefix(cls) -> bool:
return False
@classmethod @classmethod
def is_sparse(cls) -> bool: def is_sparse(cls) -> bool:
return False return False
...@@ -207,6 +211,7 @@ class AttentionBackend(ABC): ...@@ -207,6 +211,7 @@ class AttentionBackend(ABC):
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
use_mm_prefix: bool,
device_capability: "DeviceCapability", device_capability: "DeviceCapability",
attn_type: str, attn_type: str,
) -> list[str]: ) -> list[str]:
...@@ -219,6 +224,10 @@ class AttentionBackend(ABC): ...@@ -219,6 +224,10 @@ class AttentionBackend(ABC):
invalid_reasons.append("kv_cache_dtype not supported") invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size): if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported") invalid_reasons.append("block_size not supported")
if use_mm_prefix and not cls.supports_mm_prefix():
invalid_reasons.append(
"partial multimodal token full attention not supported"
)
if use_mla != cls.is_mla(): if use_mla != cls.is_mla():
if use_mla: if use_mla:
invalid_reasons.append("MLA not supported") invalid_reasons.append("MLA not supported")
...@@ -289,6 +298,16 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -289,6 +298,16 @@ class AttentionImpl(ABC, Generic[T]):
# even if they can return lse (for efficiency reasons) # even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False need_to_return_lse_for_decode: bool = False
# Whether this attention implementation supports pre-quantized query input.
# When True, the attention layer will quantize queries before passing them
# to this backend, allowing torch.compile to fuse the quantization with
# previous operations. This is typically supported when using FP8 KV cache
# with compatible attention kernels (e.g., TRT-LLM).
# Subclasses should set this in __init__.
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
dcp_world_size: int dcp_world_size: int
dcp_rank: int dcp_rank: int
...@@ -368,22 +387,6 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -368,22 +387,6 @@ class AttentionImpl(ABC, Generic[T]):
""" """
return False return False
def supports_quant_query_input(self) -> bool:
"""
Check if this attention implementation supports pre-quantized query input.
When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO add support to more backends:
https://github.com/vllm-project/vllm/issues/25584
Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return False
def process_weights_after_loading(self, act_dtype: torch.dtype): def process_weights_after_loading(self, act_dtype: torch.dtype):
pass pass
......
...@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig ...@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
UnquantizedLinearMethod, UnquantizedLinearMethod,
...@@ -88,7 +89,10 @@ def maybe_get_vit_flash_attn_backend( ...@@ -88,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA: if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func from aiter import flash_attn_varlen_func
else: else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func try:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
else: else:
flash_attn_varlen_func = None flash_attn_varlen_func = None
...@@ -230,6 +234,10 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -230,6 +234,10 @@ class Attention(nn.Module, AttentionLayerBase):
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None self.has_sink = extra_impl_args.get("sinks") is not None
# NOTE: model_config may be None during certain tests
model_config = vllm_config.model_config
self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm
# During model initialization, the default dtype is set as the model # During model initialization, the default dtype is set as the model
# weight and activation dtype. # weight and activation dtype.
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
...@@ -241,11 +249,30 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -241,11 +249,30 @@ class Attention(nn.Module, AttentionLayerBase):
block_size, block_size,
use_mla=False, use_mla=False,
has_sink=self.has_sink, has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
attn_type=attn_type, attn_type=attn_type,
) )
else: else:
self.attn_backend = attn_backend self.attn_backend = attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA"
)
):
logger.warning_once(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = self.attn_backend.get_impl_cls() impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls( self.impl = impl_cls(
num_heads, num_heads,
...@@ -303,7 +330,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -303,7 +330,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.query_quant = None self.query_quant = None
if ( if (
self.kv_cache_dtype.startswith("fp8") self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input() and self.impl.supports_quant_query_input
): ):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
...@@ -338,7 +365,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -338,7 +365,7 @@ class Attention(nn.Module, AttentionLayerBase):
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"} assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
# check if query quantization is supported # check if query quantization is supported
if self.impl.supports_quant_query_input(): if self.impl.supports_quant_query_input:
query, _ = self.query_quant(query, self._q_scale) query, _ = self.query_quant(query, self._q_scale)
if self.use_output: if self.use_output:
...@@ -623,6 +650,23 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -623,6 +650,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla=True, use_mla=True,
use_sparse=use_sparse, use_sparse=use_sparse,
) )
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
)
):
logger.warning_once(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls( self.impl = impl_cls(
num_heads=self.num_heads, num_heads=self.num_heads,
......
...@@ -103,7 +103,7 @@ def create_cross_attention_backend( ...@@ -103,7 +103,7 @@ def create_cross_attention_backend(
# needed here to know how many tokens to attend to from the cached # needed here to know how many tokens to attend to from the cached
# cross-attention KV cache. # cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata.seq_lens_cpu = torch.from_numpy( new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu common_attn_metadata.encoder_seq_lens_cpu
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment