Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import sys
from typing import Any
import pytest
......@@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS,
MODEL_NAME,
POOLING_MODEL_NAME,
......@@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
CustomLogitprocSource,
DummyLogitsProcessor,
WrappedPerReqLogitsProcessor,
dummy_module,
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
......@@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
kwargs: dict[str, list[str | type[LogitsProcessor]]] = {}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
# Scenario: load logitproc based on fully-qualified class name (FQCN)
# Inject dummy module which defines logitproc
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
# Scenario: load logitproc from provided class object
......
......@@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS,
MODEL_NAME,
TEMP_GREEDY,
dummy_module,
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
......@@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
main.main()
def _server_with_logitproc_module(
def _server_with_logitproc_fqcn(
env_dict: dict[str, str] | None,
model: str,
vllm_serve_args: list[str],
) -> None:
"""Start vLLM server, inject module with dummy logitproc"""
# Patch `modules` to inject dummy logitproc module
from vllm.entrypoints.cli import main
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
# fork is required for workers to see entrypoint patch
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
if env_dict is not None:
os.environ.update(env_dict)
......@@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
if request.param:
# Launch server, append FQCN argument, inject dummy logitproc module
args = default_server_args + request.param
_server_fxn = _server_with_logitproc_module
_server_fxn = _server_with_logitproc_fqcn
else:
# Launch server, inject dummy logitproc entrypoint
args = default_server_args
......
......@@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY = 0.0
MAX_TOKENS = 20
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "DummyModule"
DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils"
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
def test_iteration_stats_repr():
iteration_stats = IterationStats()
assert repr(iteration_stats).startswith("IterationStats(")
def test_prefill_kv_computed_with_cache():
"""Test that prefill KV compute correctly excludes cached tokens."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 5.0
req_stats.num_generation_tokens = 50
# Case 1: With prefix cache (1200 tokens cached)
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=10000,
max_tokens_param=100,
req_stats=req_stats,
num_cached_tokens=1200,
)
finished_req = iteration_stats.finished_requests[0]
assert finished_req.num_prompt_tokens == 10000
assert finished_req.num_cached_tokens == 1200
# Verify calculation: prefill KV = prompt tokens - cached tokens
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 8800 # 10000 - 1200
def test_prefill_kv_computed_no_cache():
"""Test prefill KV compute without prefix caching."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 2.0
req_stats.num_generation_tokens = 10
# Case 2: No prefix cache
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=2000,
max_tokens_param=100,
req_stats=req_stats,
num_cached_tokens=0,
)
finished_req = iteration_stats.finished_requests[0]
assert finished_req.num_prompt_tokens == 2000
assert finished_req.num_cached_tokens == 0
# Verify calculation: prefill KV = full prompt when no cache
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 2000
def test_prefill_kv_computed_edge_cases():
"""Test edge cases for prefill KV compute calculation."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 1.0
req_stats.num_generation_tokens = 1
# Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully)
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=100,
max_tokens_param=10,
req_stats=req_stats,
num_cached_tokens=-1,
)
finished_req = iteration_stats.finished_requests[0]
# max() should handle negative values
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 100 # Should treat negative as 0
# Case 4: All tokens cached (shouldn't happen in practice)
iteration_stats2 = IterationStats()
iteration_stats2.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=100,
max_tokens_param=10,
req_stats=req_stats,
num_cached_tokens=100,
)
finished_req2 = iteration_stats2.finished_requests[0]
prefill_kv_computed2 = finished_req2.num_prompt_tokens - max(
finished_req2.num_cached_tokens, 0
)
assert prefill_kv_computed2 == 0 # All cached, nothing computed
......@@ -339,7 +339,7 @@ def test_load_model(
"multi-token eagle spec decode on current platform"
)
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# Setup draft model mock
......@@ -436,7 +436,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
"because it requires special input mocking."
)
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
# Use GPU device
......@@ -543,6 +543,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.TREE_ATTN
)
elif attn_backend == "ROCM_AITER_FA":
attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.ROCM_AITER_FA
)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")
......
......@@ -47,7 +47,7 @@ def test_eagle_max_len(
"multi-token eagle spec decode on current platform"
)
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
llm = LLM(
......@@ -82,7 +82,7 @@ def test_eagle_max_len(
len(o.prompt_token_ids)
< 80
< len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
< 200
<= 200
), (
"This test is only meaningful if the output "
"is longer than the eagle max length"
......
......@@ -5,6 +5,7 @@ import torch
from vllm.config import SpeculativeConfig
from vllm.model_executor.models.interfaces import supports_eagle3
from vllm.platforms import current_platform
@pytest.mark.parametrize(
......@@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
pytest.param(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
id="qwen3-eagle3-speculator-w4a16-verifier",
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="The tests are skipped on rocm platform.",
),
),
],
)
......
......@@ -88,8 +88,8 @@ def forward_attention(
query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc.cpu(),
seq_lens=seq_lens,
seq_lens_cpu=seq_lens.cpu(),
num_computed_tokens_cpu=context_lens.cpu(),
_seq_lens_cpu=seq_lens.cpu(),
_num_computed_tokens_cpu=context_lens.cpu(),
num_reqs=batch_size,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from concurrent.futures import Future
import pytest
from transformers import AutoTokenizer
from vllm.config import StructuredOutputsConfig, VllmConfig
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.request import Request
......@@ -116,3 +121,72 @@ def test_grammar_bitmask_with_specdec():
) # EOS not the final token
grammar_bitmask(request, prompt[i:]) # EOS not present
grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
@pytest.mark.parametrize("async_grammar", [True, False])
def test_grammar_init_async_and_sync(async_grammar):
"""Test grammar initialization works correctly in both async and sync modes.
This test validates that the distributed_executor_backend config option
correctly controls whether grammar compilation happens asynchronously
(via executor.submit) or synchronously. When set to "external_launcher",
grammar compilation is synchronous to avoid deadlocks.
"""
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
prompt = tokenizer.encode('{"a": "b"}')
# Use "external_launcher" for sync mode, None for async mode
executor_backend = None if async_grammar else "external_launcher"
vllm_config = VllmConfig(
model_config=ModelConfig(tokenizer=TOKENIZER),
structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
parallel_config=ParallelConfig(distributed_executor_backend=executor_backend),
)
structured_output_manager = StructuredOutputManager(vllm_config)
sampling_params = SamplingParams(
structured_outputs=StructuredOutputsParams(
json='{"type": "object"}',
),
)
sampling_params.structured_outputs._backend = "guidance"
request = Request(
"test_request",
prompt_token_ids=prompt,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=tokenizer.eos_token_id,
)
structured_output_manager.grammar_init(request)
# Check the internal _grammar type immediately after init
# Before _check_grammar_completion is called, async mode should have a Future
raw_grammar = request.structured_output_request._grammar
if async_grammar:
assert isinstance(raw_grammar, Future), (
"Async mode should store a Future before completion"
)
else:
assert not isinstance(raw_grammar, Future), (
"Sync mode should store the grammar directly, not a Future"
)
# Wait for grammar to be ready (handles both async and sync cases)
start_time = time.time()
while not request.structured_output_request._check_grammar_completion():
if time.time() - start_time > 5: # 5-second timeout
pytest.fail("Grammar compilation timed out")
time.sleep(0.01)
# After completion, _grammar should no longer be a Future
assert not isinstance(request.structured_output_request._grammar, Future)
# Verify grammar is properly initialized and functional
grammar = request.structured_output_request.grammar
assert grammar is not None
assert not grammar.is_terminated()
# Verify the grammar can accept valid tokens
assert grammar.accept_tokens(request.request_id, prompt)
......@@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
request.use_structured_output = True
request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
request.num_computed_tokens = 5
return request
def test_should_fill_bitmask_with_enable_in_reasoning(
......
......@@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
def test_multimodal_kwargs():
e1 = MultiModalFieldElem(
"audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()
"audio",
"a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField(),
)
e2 = MultiModalFieldElem(
"video",
"v0",
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
MultiModalFlatField(
slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
dim=0,
),
)
e3 = MultiModalFieldElem(
"image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4)
"image",
"i0",
torch.zeros(1000, dtype=torch.int32),
MultiModalSharedField(batch_size=4),
)
e4 = MultiModalFieldElem(
"image",
"i1",
torch.zeros(1000, dtype=torch.int32),
MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2),
MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
)
audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2])
......@@ -138,8 +147,8 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14306, +-20 for minor changes
assert 14275 <= total_len <= 14325
# expected total encoding length, should be 14395, +-20 for minor changes
assert 14375 <= total_len <= 14425
decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems)
......
......@@ -6,8 +6,10 @@ import pytest
import torch
from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.config import (
AttentionConfig,
CacheConfig,
ModelConfig,
ParallelConfig,
......@@ -761,7 +763,11 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
@pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_hybrid_attention_mamba_tensor_shapes():
"""
The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers
......@@ -802,11 +808,13 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
cache_dtype="auto",
)
parallel_config = ParallelConfig()
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASHINFER)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
attention_config=attention_config,
)
layer_0 = "model.layers.0.self_attn.attn"
......@@ -816,8 +824,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
layer_4 = "model.layers.4.mixer"
layer_5 = "model.layers.5.mixer"
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with set_current_vllm_config(vllm_config):
hf_config = vllm_config.model_config.hf_config
fwd_context = {}
for key in [layer_0, layer_1]:
......@@ -849,9 +856,6 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
assert fwd_context is not None
vllm_ctx = vllm_config.compilation_config.static_forward_context
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()
......
......@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import vllm.envs as envs
from vllm.profiler.gpu_profiler import WorkerProfiler
from vllm.config import ProfilerConfig
from vllm.profiler.wrapper import WorkerProfiler
class ConcreteWorkerProfiler(WorkerProfiler):
......@@ -11,11 +11,11 @@ class ConcreteWorkerProfiler(WorkerProfiler):
A basic implementation of a worker profiler for testing purposes.
"""
def __init__(self):
def __init__(self, profiler_config: ProfilerConfig):
self.start_call_count = 0
self.stop_call_count = 0
self.should_fail_start = False
super().__init__()
super().__init__(profiler_config)
def _start(self) -> None:
if self.should_fail_start:
......@@ -26,17 +26,19 @@ class ConcreteWorkerProfiler(WorkerProfiler):
self.stop_call_count += 1
@pytest.fixture(autouse=True)
def reset_mocks():
"""Fixture to reset mocks and env variables before each test."""
envs.VLLM_PROFILER_DELAY_ITERS = 0
envs.VLLM_PROFILER_MAX_ITERS = 0
@pytest.fixture
def default_profiler_config():
return ProfilerConfig(
profiler="torch",
torch_profiler_dir="/tmp/mock",
delay_iterations=0,
max_iterations=0,
)
def test_immediate_start_stop():
def test_immediate_start_stop(default_profiler_config):
"""Test standard start without delay."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
assert profiler._running is True
assert profiler._active is True
......@@ -48,10 +50,10 @@ def test_immediate_start_stop():
assert profiler.stop_call_count == 1
def test_delayed_start():
def test_delayed_start(default_profiler_config):
"""Test that profiler waits for N steps before actually starting."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
# User requests start
profiler.start()
......@@ -71,10 +73,10 @@ def test_delayed_start():
assert profiler.start_call_count == 1
def test_max_iterations():
def test_max_iterations(default_profiler_config):
"""Test that profiler stops automatically after max iterations."""
envs.VLLM_PROFILER_MAX_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
assert profiler._running is True
......@@ -95,12 +97,11 @@ def test_max_iterations():
assert profiler.stop_call_count == 1
def test_delayed_start_and_max_iters():
def test_delayed_start_and_max_iters(default_profiler_config):
"""Test combined delayed start and max iterations."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
envs.VLLM_PROFILER_MAX_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
default_profiler_config.max_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
# Step 1
......@@ -127,9 +128,9 @@ def test_delayed_start_and_max_iters():
assert profiler.stop_call_count == 1
def test_idempotency():
def test_idempotency(default_profiler_config):
"""Test that calling start/stop multiple times doesn't break logic."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Double Start
profiler.start()
......@@ -142,10 +143,10 @@ def test_idempotency():
assert profiler.stop_call_count == 1 # Should only stop once
def test_step_inactive():
def test_step_inactive(default_profiler_config):
"""Test that stepping while inactive does nothing."""
envs.VLLM_PROFILER_DELAY_ITERS = 2
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 2
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Not started yet
profiler.step()
......@@ -155,9 +156,9 @@ def test_step_inactive():
assert profiler.start_call_count == 0
def test_start_failure():
def test_start_failure(default_profiler_config):
"""Test behavior when the underlying _start method raises exception."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.should_fail_start = True
profiler.start()
......@@ -168,9 +169,9 @@ def test_start_failure():
assert profiler.start_call_count == 0 # Logic failed inside start
def test_shutdown():
def test_shutdown(default_profiler_config):
"""Test that shutdown calls stop only if running."""
profiler = ConcreteWorkerProfiler()
profiler = ConcreteWorkerProfiler(default_profiler_config)
# Case 1: Not running
profiler.shutdown()
......@@ -182,10 +183,10 @@ def test_shutdown():
assert profiler.stop_call_count == 1
def test_mixed_delay_and_stop():
def test_mixed_delay_and_stop(default_profiler_config):
"""Test manual stop during the delay period."""
envs.VLLM_PROFILER_DELAY_ITERS = 5
profiler = ConcreteWorkerProfiler()
default_profiler_config.delay_iterations = 5
profiler = ConcreteWorkerProfiler(default_profiler_config)
profiler.start()
profiler.step()
......
......@@ -10,9 +10,10 @@ set -ex
CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
PPLX_COMMIT_HASH=${PPLX_COMMIT_HASH:-"12cecfd"}
DEEPEP_COMMIT_HASH=${DEEPEP_COMMIT_HASH:-"73b6ea4"}
NVSHMEM_VER=3.3.9
NVSHMEM_VER=3.3.24 # Suppports both CUDA 12 and 13
WORKSPACE=${WORKSPACE:-$(pwd)/ep_kernels_workspace}
MODE=${MODE:-install}
CUDA_VERSION_MAJOR=$(${CUDA_HOME}/bin/nvcc --version | egrep -o "release [0-9]+" | cut -d ' ' -f 2)
# Parse arguments
while [[ $# -gt 0 ]]; do
......@@ -75,11 +76,9 @@ ARCH=$(uname -m)
case "${ARCH,,}" in
x86_64|amd64)
NVSHMEM_SUBDIR="linux-x86_64"
NVSHMEM_FILE="libnvshmem-linux-x86_64-${NVSHMEM_VER}_cuda12-archive.tar.xz"
;;
aarch64|arm64)
NVSHMEM_SUBDIR="linux-sbsa"
NVSHMEM_FILE="libnvshmem-linux-sbsa-${NVSHMEM_VER}_cuda12-archive.tar.xz"
;;
*)
echo "Unsupported architecture: ${ARCH}" >&2
......@@ -87,6 +86,7 @@ case "${ARCH,,}" in
;;
esac
NVSHMEM_FILE="libnvshmem-${NVSHMEM_SUBDIR}-${NVSHMEM_VER}_cuda${CUDA_VERSION_MAJOR}-archive.tar.xz"
NVSHMEM_URL="https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/${NVSHMEM_SUBDIR}/${NVSHMEM_FILE}"
pushd "$WORKSPACE"
......@@ -142,13 +142,6 @@ clone_repo() {
fi
}
deepep_cuda13_patch() {
cuda_version_major=$(${CUDA_HOME}/bin/nvcc --version | egrep -o "release [0-9]+" | cut -d ' ' -f 2)
if [ ${cuda_version_major} -ge 13 ]; then
sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py"
fi
}
do_build() {
local repo=$1
local name=$2
......@@ -160,8 +153,9 @@ do_build() {
clone_repo "$repo" "$name" "$key" "$commit"
cd "$name"
if [ "$name" == "DeepEP" ]; then
deepep_cuda13_patch
# DeepEP CUDA 13 patch
if [[ "$name" == "DeepEP" && "${CUDA_VERSION_MAJOR}" -ge 13 ]]; then
sed -i "s|f'{nvshmem_dir}/include']|f'{nvshmem_dir}/include', '${CUDA_HOME}/include/cccl']|" "setup.py"
fi
if [ "$MODE" = "install" ]; then
......
......@@ -3,9 +3,7 @@
import glob
requires_files = glob.glob("requirements/*.txt")
requires_files += ["pyproject.toml"]
for file in requires_files:
for file in (*glob.glob("requirements/*.txt"), "pyproject.toml"):
print(f">>> cleaning {file}")
with open(file) as f:
lines = f.readlines()
......@@ -17,5 +15,4 @@ for file in requires_files:
f.write(line)
else:
print(line.strip())
print(f"<<< done cleaning {file}")
print()
print(f"<<< done cleaning {file}\n")
......@@ -9,6 +9,8 @@ import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
_FP8_DTYPE = current_platform.fp8_dtype()
def is_aiter_found() -> bool:
from importlib.util import find_spec
......@@ -22,6 +24,15 @@ def is_aiter_found() -> bool:
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might because that the get_gfx() is wrapped as a custom op.
if IS_AITER_FOUND:
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
......@@ -43,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return wrapper
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import dtypes
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......@@ -283,6 +264,28 @@ def _rocm_aiter_grouped_topk_fake(
pass
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None
def _check_aiter_mla_fp8_support() -> bool:
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
global _AITER_MLA_SUPPORTS_FP8
if _AITER_MLA_SUPPORTS_FP8 is None:
try:
import inspect
from aiter.mla import mla_decode_fwd
sig = inspect.signature(mla_decode_fwd)
_AITER_MLA_SUPPORTS_FP8 = (
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
)
except Exception:
_AITER_MLA_SUPPORTS_FP8 = False
return _AITER_MLA_SUPPORTS_FP8
def _rocm_aiter_mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
......@@ -299,6 +302,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
) -> None:
from aiter.mla import mla_decode_fwd
kwargs = {
"sm_scale": sm_scale,
"logit_cap": logit_cap,
}
# Only pass q_scale and kv_scale if the aiter library supports them
if _check_aiter_mla_fp8_support():
kwargs["q_scale"] = q_scale
kwargs["kv_scale"] = kv_scale
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
......@@ -308,10 +321,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
q_scale=q_scale,
kv_scale=kv_scale,
**kwargs,
)
......@@ -438,6 +448,195 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
return torch.empty_like(x), torch.empty_like(residual)
def _rocm_aiter_per_tensor_quant_impl(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import per_tensor_quant_hip
return per_tensor_quant_hip(x, scale, quant_dtype)
def _rocm_aiter_per_tensor_quant_fake(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x, dtype=quant_dtype), torch.empty(
1, dtype=torch.float32, device=x.device
)
def _rocm_aiter_per_token_quant_impl(
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import dynamic_per_token_scaled_quant
assert quant_dtype in [torch.int8, _FP8_DTYPE]
out_shape = x.shape
out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
if scale is None:
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
dynamic_per_token_scaled_quant(
out,
x,
scale,
scale_ub=None,
shuffle_scale=False,
num_rows=None,
num_rows_factor=1,
)
return out, scale
def _rocm_aiter_per_token_quant_fake(
x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
out_shape = x.shape
return (
torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=residual,
)
return (x_quant, x_quant_scales, res)
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
torch.empty_like(residual, device=residual.device),
)
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
x,
weight,
variance_epsilon,
None,
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
res1=None,
)
return (x_quant, x_quant_scales)
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
def _rocm_aiter_group_fp8_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
from aiter import QuantType, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
def _rocm_aiter_group_fp8_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
return act_mul_and_fp8_group_quant(
x,
activation="silu",
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
)
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
x: torch.Tensor,
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
(N_half + group_size - 1) // group_size,
),
dtype=torch.float32,
device=x.device,
)
return x_fp8, out_bs
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
......@@ -473,7 +672,7 @@ class rocm_aiter_ops:
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
return cls.is_linear_enabled()
@classmethod
@if_aiter_supported
......@@ -548,14 +747,6 @@ class rocm_aiter_ops:
)
# register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_group_fp8_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl,
......@@ -615,27 +806,62 @@ class rocm_aiter_ops:
direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8_blockscale",
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_group_fp8_quant",
op_func=_rocm_aiter_group_fp8_quant_impl,
fake_impl=_rocm_aiter_group_fp8_quant_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_per_tensor_quant",
op_func=_rocm_aiter_per_tensor_quant_impl,
mutates_args=[],
fake_impl=_rocm_aiter_per_tensor_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_per_token_quant",
op_func=_rocm_aiter_per_token_quant_impl,
mutates_args=["scale"],
fake_impl=_rocm_aiter_per_token_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True
@staticmethod
......@@ -830,6 +1056,22 @@ class rocm_aiter_ops:
kv_scale=kv_scale,
)
@staticmethod
def per_tensor_quant(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)
@staticmethod
def per_token_quant(
x: torch.Tensor,
quant_dtype: torch.dtype,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor,
......
......@@ -441,6 +441,46 @@ def rms_norm_dynamic_per_token_quant(
return output, scales
# fused quant layer norm ops blocked
def rms_norm_per_block_quant(
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
group_size: list[int],
scale_ub: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
is_scale_transposed: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert len(group_size) == 2
output = torch.empty_like(input, dtype=quant_dtype)
if is_scale_transposed:
scales = torch.empty(
(input.shape[-1] // group_size[1], input.numel() // input.shape[-1]),
device=input.device,
dtype=torch.float32,
).transpose(0, 1)
else:
scales = torch.empty(
(input.numel() // input.shape[-1], input.shape[-1] // group_size[1]),
device=input.device,
dtype=torch.float32,
)
torch.ops._C.rms_norm_per_block_quant(
output,
input,
weight,
scales,
epsilon,
scale_ub,
residual,
group_size[1],
is_scale_transposed,
)
return output, scales
# quantization ops
# awq
def awq_dequantize(
......@@ -660,6 +700,10 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor:
return torch.empty_like(b, memory_format=torch.contiguous_format)
@register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor:
return torch.empty_like(b, memory_format=torch.contiguous_format)
if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
......@@ -1023,6 +1067,7 @@ def get_cutlass_moe_mm_problem_sizes(
n: int,
k: int,
blockscale_offsets: torch.Tensor | None = None,
force_swap_ab: bool | None = None,
):
"""
Compute only the per-expert problem sizes needed by the two grouped matrix
......@@ -1032,9 +1077,20 @@ def get_cutlass_moe_mm_problem_sizes(
- problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
multiplication for the two grouped MMs
used in the fused MoE operation.
Optional:
- force_swap_ab: If set to True or False, explicitly enable or disable the
A/B input swap optimization. If None (default), the swap
is selected automatically based on tensor sizes.
"""
return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets
topk_ids,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
blockscale_offsets,
force_swap_ab,
)
......@@ -1422,6 +1478,78 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor:
return torch.ops._C.cutlass_encode_and_reorder_int4b(b)
def cutlass_w4a8_moe_mm(
out_tensors: torch.Tensor,
a_tensors: torch.Tensor,
b_tensors: torch.Tensor,
a_scales: torch.Tensor,
b_scales: torch.Tensor,
b_group_scales: torch.Tensor,
b_group_size: int,
expert_offsets: torch.Tensor,
problem_sizes: torch.Tensor,
a_strides: torch.Tensor,
b_strides: torch.Tensor,
c_strides: torch.Tensor,
group_scale_strides: torch.Tensor,
maybe_schedule: str | None = None,
):
"""
Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
and both per-channel + per-token scaling in the epilogue.
Args:
out_tensors:
Output buffer for all experts (updated in-place).
a_tensors:
FP8 (E4M3FN) activations for all experts.
b_tensors:
INT4-packed weight matrix for all experts, packed to INT32
a_scales:
Per-token FP8 activation scales, applied in the epilogue.
b_scales:
Per-channel FP8 weight scales for each expert, applied in the epilogue.
b_group_scales:
FP8 scale values for group-wise INT4 weight blocks.
b_group_size:
Number of elements grouped under each entry of b_group_scales.
expert_offsets:
Cumulative token offsets
problem_sizes:
Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
a/b/c/group_scale_strides:
Strides describing the memory layout of the input tensors.
maybe_schedule:
Optional override to choose a specific kernel or epilogue schedule.
Returns:
out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result.
"""
return torch.ops._C.cutlass_w4a8_moe_mm(
out_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
b_group_scales,
b_group_size,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
c_strides,
group_scale_strides,
maybe_schedule,
)
def cutlass_encode_and_reorder_int4b_grouped(
b_tensors: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors)
if hasattr(torch.ops._C, "permute_cols"):
@register_fake("_C::permute_cols")
......@@ -1603,7 +1731,7 @@ def scaled_fp8_quant(
output, input, scale, scale_ub
)
else:
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
scale = torch.empty(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
......@@ -1882,6 +2010,7 @@ def moe_align_block_size(
sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> None:
torch.ops._moe_C.moe_align_block_size(
topk_ids,
......@@ -1890,6 +2019,7 @@ def moe_align_block_size(
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
expert_map,
)
......@@ -1924,6 +2054,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad: torch.Tensor,
adapter_enabled: torch.Tensor,
lora_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> None:
torch.ops._moe_C.moe_lora_align_block_size(
topk_ids,
......@@ -1938,6 +2069,7 @@ def moe_lora_align_block_size(
num_tokens_post_pad,
adapter_enabled,
lora_ids,
expert_map,
)
......
......@@ -166,6 +166,10 @@ class AttentionBackend(ABC):
def supports_sink(cls) -> bool:
return False
@classmethod
def supports_mm_prefix(cls) -> bool:
return False
@classmethod
def is_sparse(cls) -> bool:
return False
......@@ -207,6 +211,7 @@ class AttentionBackend(ABC):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
......@@ -219,6 +224,10 @@ class AttentionBackend(ABC):
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
if use_mm_prefix and not cls.supports_mm_prefix():
invalid_reasons.append(
"partial multimodal token full attention not supported"
)
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
......@@ -289,6 +298,16 @@ class AttentionImpl(ABC, Generic[T]):
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False
# Whether this attention implementation supports pre-quantized query input.
# When True, the attention layer will quantize queries before passing them
# to this backend, allowing torch.compile to fuse the quantization with
# previous operations. This is typically supported when using FP8 KV cache
# with compatible attention kernels (e.g., TRT-LLM).
# Subclasses should set this in __init__.
# TODO add support to more backends:
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
dcp_world_size: int
dcp_rank: int
......@@ -368,22 +387,6 @@ class AttentionImpl(ABC, Generic[T]):
"""
return False
def supports_quant_query_input(self) -> bool:
"""
Check if this attention implementation supports pre-quantized query input.
When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO add support to more backends:
https://github.com/vllm-project/vllm/issues/25584
Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return False
def process_weights_after_loading(self, act_dtype: torch.dtype):
pass
......
......@@ -25,6 +25,7 @@ from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
UnquantizedLinearMethod,
......@@ -88,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
try:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
else:
flash_attn_varlen_func = None
......@@ -230,6 +234,10 @@ class Attention(nn.Module, AttentionLayerBase):
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
# NOTE: model_config may be None during certain tests
model_config = vllm_config.model_config
self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
......@@ -241,11 +249,30 @@ class Attention(nn.Module, AttentionLayerBase):
block_size,
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
attn_type=attn_type,
)
else:
self.attn_backend = attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA"
)
):
logger.warning_once(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
num_heads,
......@@ -303,7 +330,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.impl.supports_quant_query_input()
and self.impl.supports_quant_query_input
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
......@@ -338,7 +365,7 @@ class Attention(nn.Module, AttentionLayerBase):
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
# check if query quantization is supported
if self.impl.supports_quant_query_input():
if self.impl.supports_quant_query_input:
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
......@@ -623,6 +650,23 @@ class MLAAttention(nn.Module, AttentionLayerBase):
use_mla=True,
use_sparse=use_sparse,
)
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
)
):
logger.warning_once(
"Disabling prefix caching for TRITON_MLA / FLASHINFER "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
num_heads=self.num_heads,
......
......@@ -103,7 +103,7 @@ def create_cross_attention_backend(
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata.seq_lens_cpu = torch.from_numpy(
new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment