Commit c1cacde6 authored by weishb

vllm-omni_0.15.0.rc1+fix1 first commit

parent 35607782
import argparse
import asyncio
from typing import Any
from vllm.benchmarks.serve import main_async
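# Thin synchronous wrapper around vLLM's async serving benchmark entry point.
# `args` is expected to be the argparse.Namespace produced by vLLM's benchmark CLI parser.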
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))
"""
Configuration module for vLLM-Omni.
"""
from vllm_omni.config.lora import LoRAConfig
from vllm_omni.config.model import OmniModelConfig
__all__ = [
"OmniModelConfig",
"LoRAConfig",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# for now, it suffices to use vLLM's implementation directly
# as this is a user-facing variable, it is defined here so that users can directly import LoRAConfig from vllm_omni
from vllm.config.lora import LoRAConfig
__all__ = ["LoRAConfig"]
import warnings
from dataclasses import field
from typing import Any
import torch
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from vllm.config import ModelConfig, config
from vllm.config.model import (
_RUNNER_CONVERTS,
_get_and_verify_dtype,
get_served_model_name,
)
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
get_config,
get_hf_image_processor_config,
get_hf_text_config,
get_pooling_config,
)
from vllm.transformers_utils.gguf_utils import is_gguf, maybe_patch_hf_config_from_gguf
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.v1.attention.backends.registry import AttentionBackendEnum
import vllm_omni.model_executor.models as me_models
logger = init_logger(__name__)
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class OmniModelConfig(ModelConfig):
"""Configuration for Omni models, extending the base ModelConfig.
This configuration class extends the base vLLM ModelConfig with
omni-specific fields for multi-stage pipeline processing.
Attributes:
stage_id: Identifier for the stage in a multi-stage pipeline (default: 0)
async_chunk: If True, enable asynchronous chunk transfer through the stage connector (default: False)
model_stage: Stage type identifier, e.g., "thinker" or "talker"
(default: "thinker")
model_arch: Model architecture name
(default: "Qwen2_5OmniForConditionalGeneration")
engine_output_type: Optional output type specification for the engine.
Used to route outputs to appropriate processors (e.g., "image",
"audio", "latents"). If None, output type is inferred.
stage_connector_config: Stage connector configuration dictionary.
Contains "name" (connector name), "extra" (extra connector config).
Example:
>>> config = OmniModelConfig(
... stage_id=0,
... model_stage="thinker",
... model_arch="Qwen2_5OmniForConditionalGeneration"
... )
"""
stage_id: int = 0
async_chunk: bool = False
model_stage: str = "thinker"
model_arch: str = "Qwen2_5OmniForConditionalGeneration"
engine_output_type: str | None = None
hf_config_name: str | None = None
custom_process_next_stage_input_func: str | None = None
stage_connector_config: dict[str, Any] = field(
default_factory=lambda: {
"name": "SharedMemoryConnector",
"extra": {},
}
)
omni_kv_config: dict | None = None
@property
def registry(self):
return me_models.OmniModelRegistry
@property
def architectures(self) -> list[str]:
return [self.model_arch]
def draw_hf_text_config(self):
# transformers' get_text_config() is used to obtain the text config (e.g. from thinker_config).
# To handle the case where each model stage has its own text config,
# we draw the text config from the corresponding stage's config.
if self.hf_config_name is None:
return get_hf_text_config(self.hf_config)
try:
# Try to get the stage-specific config (e.g., thinker_config, talker_config)
stage_config = getattr(self.hf_config, self.hf_config_name)
return stage_config.get_text_config()
except AttributeError:
# Fallback: if the attribute doesn't exist, use the default get_hf_text_config
logger.warning(
f"Config attribute '{self.hf_config_name}' not found in hf_config, "
"falling back to default get_hf_text_config"
)
return get_hf_text_config(self.hf_config)
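# Illustrative example: for a Qwen2.5-Omni style config with hf_config_name="thinker_config",
# this resolves hf_config.thinker_config.get_text_config(); with hf_config_name=None it
# falls back to get_hf_text_config(hf_config).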
def __post_init__(
self,
# Multimodal config init vars
limit_mm_per_prompt: dict[str, int | dict[str, int]] | None,
enable_mm_embeds: bool | None,
media_io_kwargs: dict[str, dict[str, Any]] | None,
mm_processor_kwargs: dict[str, Any] | None,
mm_processor_cache_gb: float | None,
mm_processor_cache_type: MMCacheType | None,
mm_shm_cache_max_object_size_mb: int | None,
mm_encoder_only: bool | None,
mm_encoder_tp_mode: MMEncoderTPMode | None,
mm_encoder_attn_backend: AttentionBackendEnum | str | None,
interleave_mm_strings: bool | None,
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
) -> None:
# Set served_model_name before maybe_model_redirect(self.model) so it reflects the original model name
self.served_model_name = get_served_model_name(self.model, self.served_model_name)
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
self.tokenizer = self.model
if self.tokenizer_revision is None:
self.tokenizer_revision = self.revision
self.tokenizer = maybe_model_redirect(self.tokenizer)
if isinstance(self.hf_config_path, str):
self.hf_config_path = maybe_model_redirect(self.hf_config_path)
if callable(self.hf_overrides):
hf_overrides_kw = {}
hf_overrides_fn = self.hf_overrides
dict_overrides: dict[str, Any] = {}
else:
# Separate dict overrides from flat ones
# We'll determine how to apply dict overrides after loading the config
hf_overrides_kw = {}
dict_overrides = {}
for key, value in self.hf_overrides.items():
if isinstance(value, dict):
dict_overrides[key] = value
else:
hf_overrides_kw[key] = value
hf_overrides_fn = None
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
if self.override_attention_dtype is not None and not current_platform.is_rocm():
warnings.warn(
"override-attention-dtype is set but not using ROCm platform",
stacklevel=2,
)
if self.enable_sleep_mode and not current_platform.is_sleep_mode_available():
raise ValueError("Sleep mode is not supported on current platform.")
hf_config = get_config(
self.hf_config_path or self.model,
self.trust_remote_code,
self.revision,
self.code_revision,
self.config_format,
hf_overrides_kw=hf_overrides_kw,
hf_overrides_fn=hf_overrides_fn,
)
hf_config = maybe_patch_hf_config_from_gguf(
self.model,
hf_config,
)
self.hf_config = hf_config
if dict_overrides:
self._apply_dict_overrides(hf_config, dict_overrides)
self.hf_text_config = self.draw_hf_text_config()
self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=self.hf_token, revision=self.revision
)
self.model_arch_config = self.get_model_arch_config()
if self.convert == "mm_encoder_only":
logger.warning_once(
"`--convert mm_encoder_only` is deprecated and "
"will be removed in v0.15. "
"Please use --mm-encoder-only` instead."
)
mm_encoder_only = True
self.convert = "none"
architectures = self.architectures
registry = self.registry
is_generative_model = registry.is_text_generation_model(architectures, self)
is_pooling_model = registry.is_pooling_model(architectures, self)
self.runner_type = self._get_runner_type(architectures, self.runner)
self.convert_type = self._get_convert_type(architectures, self.runner_type, self.convert)
if self.runner_type == "generate" and not is_generative_model:
generate_converts = _RUNNER_CONVERTS["generate"]
if self.convert_type not in generate_converts:
# Currently we don't have any converters for generative models
raise ValueError("This model does not support `--runner generate`.")
if self.runner_type == "pooling" and not is_pooling_model:
pooling_converts = _RUNNER_CONVERTS["pooling"]
if self.convert_type not in pooling_converts:
convert_option = "<" + "|".join(pooling_converts) + ">"
raise ValueError(
"This model does not support `--runner pooling`. "
f"You can pass `--convert {convert_option} to adapt "
"it into a pooling model."
)
# Note: Initialize these attributes early because transformers fallback
# may fail to load dynamic modules in child processes
model_info, arch = registry.inspect_model_cls(architectures, self)
self._model_info = model_info
self._architecture = arch
logger.info("Resolved architecture: %s", arch)
# Init pooler config if needed
if self.runner_type == "pooling":
if self.pooler_config is None:
self.pooler_config = PoolerConfig()
base_config = get_pooling_config(self.model, self.revision)
if base_config is not None:
# Only set values that are not overridden by the user
for k, v in base_config.items():
if getattr(self.pooler_config, k) is None:
setattr(self.pooler_config, k, v)
default_seq_pooling_type = self._model_info.default_seq_pooling_type
if self.pooler_config.seq_pooling_type is None:
self.pooler_config.seq_pooling_type = default_seq_pooling_type
default_tok_pooling_type = self._model_info.default_tok_pooling_type
if self.pooler_config.tok_pooling_type is None:
self.pooler_config.tok_pooling_type = default_tok_pooling_type
self.dtype: torch.dtype = _get_and_verify_dtype(
self.model,
self.hf_config,
self.dtype,
is_pooling_model=self.runner_type == "pooling",
revision=self.revision,
)
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
if self.is_encoder_decoder:
self.mm_processor_cache_gb = 0
logger.info("Encoder-decoder model detected, disabling mm processor cache.")
# Init multimodal config if needed
if self._model_info.supports_multimodal:
if mm_encoder_tp_mode == "data" and not self._model_info.supports_multimodal_encoder_tp_data:
logger.warning_once(
"This model does not support `--mm-encoder-tp-mode data`. "
"Falling back to `--mm-encoder-tp-mode weights`."
)
mm_encoder_tp_mode = "weights"
mm_config_kwargs = dict(
limit_per_prompt=limit_mm_per_prompt,
enable_mm_embeds=enable_mm_embeds,
media_io_kwargs=media_io_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
mm_processor_cache_gb=mm_processor_cache_gb,
mm_processor_cache_type=mm_processor_cache_type,
mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
mm_encoder_only=mm_encoder_only,
mm_encoder_tp_mode=mm_encoder_tp_mode,
mm_encoder_attn_backend=mm_encoder_attn_backend,
interleave_mm_strings=interleave_mm_strings,
skip_mm_profiling=skip_mm_profiling,
video_pruning_rate=video_pruning_rate,
)
mm_config_kwargs = {k: v for k, v in mm_config_kwargs.items() if v is not None}
self.multimodal_config = MultiModalConfig(**mm_config_kwargs)
# Multimodal GGUF models must use original repo for mm processing
if is_gguf(self.tokenizer) and self.is_multimodal_model:
raise ValueError(
"Loading a multimodal GGUF model needs to use original "
"tokenizer. Please specify the unquantized hf model's "
"repo name or path using the --tokenizer argument."
)
if self.disable_sliding_window:
# Set after get_and_verify_max_len to ensure that max_model_len
# can be correctly capped to sliding window size
self.hf_text_config.sliding_window = None
# Avoid running try_verify_and_update_config multiple times
self.config_updated = False
self._try_verify_and_update_model_config()
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
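# Usage sketch (illustrative; the model path is a placeholder, and constructing the
# config runs __post_init__, which resolves the HF config from disk or the hub):
#
#   thinker_cfg = OmniModelConfig(
#       model="<path-or-repo-of-an-omni-model>",
#       stage_id=0,
#       model_stage="thinker",
#       model_arch="Qwen2_5OmniForConditionalGeneration",
#       hf_config_name="thinker_config",
#       stage_connector_config={"name": "SharedMemoryConnector", "extra": {}},
#   )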
"""
Scheduling components for vLLM-Omni.
"""
from .omni_ar_scheduler import OmniARScheduler
from .omni_generation_scheduler import OmniGenerationScheduler
from .output import OmniNewRequestData
__all__ = [
"OmniARScheduler",
"OmniGenerationScheduler",
"OmniNewRequestData",
]
from __future__ import annotations
import importlib
from collections import defaultdict
from dataclasses import asdict, dataclass
import time
from typing import Any
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.distributed.kv_events import KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.core.sched.utils import remove_all
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm_omni.core.sched.output import OmniSchedulerOutput
from vllm_omni.distributed.omni_connectors.adapter import get_chunk, put_chunk
from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
logger = init_logger(__name__)
@dataclass
class KVCacheTransferData:
request_id: str
layer_blocks: dict[str, Any]
block_ids: list[int]
metadata: dict[str, Any]
def to_dict(self) -> dict[str, Any]:
return asdict(self)
class OmniARScheduler(VLLMScheduler):
"""
OmniARScheduler: Scheduler for vLLM-Omni multimodal processing.
This scheduler extends vLLM's scheduler to support multimodal and
non-autoregressive processing with additional fields and methods
specific to vLLM-Omni.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Track requests that need KV cache transfer when finished
# Value is {"seq_len": int, "block_ids": list[int]}
self.requests_needing_kv_transfer: dict[str, dict[str, Any]] = {}
# Track requests waiting for KV transfer (blocks not freed yet)
self.waiting_for_transfer_free: set[str] = set()
# Track ACTIVE transfers (submitted to runner but not yet acked via kv_extracted_req_ids)
self.active_kv_transfers: set[str] = set()
# [Omni] Pre-parse KV transfer criteria
self.kv_transfer_criteria = self._get_kv_transfer_criteria()
# Track requests that have already triggered prefill transfer to avoid duplicates
self.transfer_triggered_requests: set[str] = set()
model_config = self.vllm_config.model_config
self.omni_connector = None
if model_config.async_chunk:
connector_config = model_config.stage_connector_config
connector_specs = ConnectorSpec(
name=connector_config.get("name", "SharedMemoryConnector"),
extra=connector_config.get("extra", {}),
)
self.omni_connector = OmniConnectorFactory.create_connector(connector_specs)
custom_process_next_stage_input_func = getattr(
self.vllm_config.model_config, "custom_process_next_stage_input_func", None
)
if custom_process_next_stage_input_func:
module_path, func_name = custom_process_next_stage_input_func.rsplit(".", 1)
module = importlib.import_module(module_path)
self.custom_process_next_stage_input_func = getattr(module, func_name)
self.stage_id = getattr(self.vllm_config.model_config, "stage_id", None)
def _get_kv_transfer_criteria(self) -> dict | None:
# Note: vllm_config is available in Scheduler after super().__init__
if not hasattr(self, "vllm_config"):
return None
omni_kv_config = getattr(self.vllm_config.model_config, "omni_kv_config", None)
if omni_kv_config:
if isinstance(omni_kv_config, dict):
return omni_kv_config.get("kv_transfer_criteria", None)
else:
return getattr(omni_kv_config, "kv_transfer_criteria", None)
return None
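# Illustrative omni_kv_config shape (a sketch inferred from how it is read here and in
# _should_transfer_kv_for_request below; it may also be an attribute-style object):
#
#   omni_kv_config = {
#       "need_send_cache": True,
#       # trigger on prefill completion ...
#       "kv_transfer_criteria": {"type": "prefill_finished"},
#       # ... or on a sentinel token:
#       # "kv_transfer_criteria": {"type": "special_token", "token_id": <sentinel_token_id>},
#   }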
def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int]) -> bool:
"""
Check triggers and process side effects (marking transfer).
Returns True if request should be STOPPED.
Returns False if request should continue (even if transfer was triggered).
"""
if not self.kv_transfer_criteria:
return False
if request.request_id in self.waiting_for_transfer_free:
return False
criteria_type = self.kv_transfer_criteria.get("type")
# Universal duplicate check for once semantics
if request.request_id in self.transfer_triggered_requests:
return False
if criteria_type == "prefill_finished":
if request.num_computed_tokens >= request.num_prompt_tokens:
logger.debug(f"[Omni] Request {request.request_id} triggered prefill_finished transfer (Non-Stop)")
self.transfer_triggered_requests.add(request.request_id)
self._mark_request_for_kv_transfer(request.request_id, request.num_computed_tokens)
# Return False means "Do NOT stop the request" -> Continue Decoding
return False
elif criteria_type == "special_token":
target_token_id = self.kv_transfer_criteria.get("token_id")
if target_token_id is not None and target_token_id in new_token_ids:
logger.debug(f"[Omni] Request {request.request_id} triggered special_token criteria (Non-Stop)")
self.transfer_triggered_requests.add(request.request_id)
# Calculate precise snapshot length (trim to sentinel)
# Find the FIRST occurrence of the sentinel
try:
idx = new_token_ids.index(target_token_id)
# seq_len = tokens_before_this_step + idx + 1 (include sentinel)
# request.num_computed_tokens already includes ALL new_token_ids
# so we subtract (len(new_token_ids) - (idx + 1))
tokens_to_exclude = len(new_token_ids) - (idx + 1)
snapshot_len = request.num_computed_tokens - tokens_to_exclude
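# Worked example (illustrative): with num_computed_tokens=100 and
# new_token_ids=[a, b, SENTINEL, c], idx=2, so tokens_to_exclude = 4 - (2 + 1) = 1
# and snapshot_len = 99, i.e. the KV snapshot ends exactly at the sentinel token.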
except ValueError:
snapshot_len = request.num_computed_tokens
# Trigger Transfer
self._mark_request_for_kv_transfer(request.request_id, snapshot_len)
# Do NOT stop request
return False
return False
def schedule(self) -> SchedulerOutput: # type: ignore[override]
scheduler_output = super().schedule()
try:
# Late import to avoid circular imports in some launch modes
from .output import OmniNewRequestData
# Rewrap base NewRequestData entries with OmniNewRequestData,
# enriching with request-level payloads
new_list = []
for nr in scheduler_output.scheduled_new_reqs:
req_id = getattr(nr, "req_id", None)
request = self.requests.get(req_id) if req_id else None
# Build omni entry preserving all base fields
omni_nr = OmniNewRequestData(
req_id=nr.req_id,
external_req_id=(getattr(request, "external_req_id", None) if request else None),
prompt_token_ids=nr.prompt_token_ids,
mm_features=nr.mm_features,
sampling_params=nr.sampling_params,
pooling_params=nr.pooling_params,
block_ids=nr.block_ids,
num_computed_tokens=nr.num_computed_tokens,
lora_request=nr.lora_request,
# Enrich with omni payloads from the live request object
prompt_embeds=(getattr(request, "prompt_embeds", None) if request else None),
additional_information=(getattr(request, "additional_information", None) if request else None),
)
new_list.append(omni_nr)
scheduler_output.scheduled_new_reqs = new_list # type: ignore[assignment]
if self.omni_connector is not None:
get_chunk(self.omni_connector, scheduler_output)
# Add information about requests needing KV cache transfer
finished_reqs = self.get_finished_requests_needing_kv_transfer()
except Exception:
# If anything goes wrong, leave the original output unchanged
init_logger(__name__).exception("Failed to wrap scheduled_new_reqs with OmniNewRequestData")
finished_reqs = {}
# Wrap in omni scheduler output to carry transfer metadata.
base_fields = SchedulerOutput.__dataclass_fields__.keys()
base_data = {name: getattr(scheduler_output, name) for name in base_fields}
return OmniSchedulerOutput(
**base_data,
finished_requests_needing_kv_transfer=finished_reqs,
)
def update_from_output(
self,
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
cudagraph_stats: CUDAGraphStat | None = model_runner_output.cudagraph_stats
perf_stats: PerfStats | None = None
if self.perf_metrics and self.perf_metrics.is_enabled():
perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output)
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: KVConnectorStats | None = (
kv_connector_output.kv_connector_stats if kv_connector_output else None
)
if kv_connector_stats and self.connector:
kv_stats = self.connector.get_kv_connector_stats()
if kv_stats:
kv_connector_stats = kv_connector_stats.aggregate(kv_stats)
failed_kv_load_req_ids = None
if kv_connector_output and kv_connector_output.invalid_block_ids:
# These blocks contain externally computed tokens that failed to
# load. Identify affected requests and adjust their computed token
# count to trigger recomputation of the invalid blocks.
failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# to avoid expensive operations inside the loop.
stopped_running_reqs: set[Request] = set()
stopped_preempted_reqs: set[Request] = set()
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
assert num_tokens_scheduled > 0
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
# Skip requests that were recovered from KV load failure
continue
request = self.requests.get(req_id)
if request is None:
# The request is already finished. This can happen if the
# request is aborted while the model is executing it (e.g.,
# in pipeline parallelism).
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []
scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
num_rejected = num_draft_tokens - num_accepted
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens.
if request.num_computed_tokens > 0:
request.num_computed_tokens -= num_rejected
# If async scheduling, num_output_placeholders also includes
# the scheduled spec tokens count and so is similarly adjusted.
if request.num_output_placeholders > 0:
request.num_output_placeholders -= num_rejected
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted,
num_invalid_spec_tokens=scheduler_output.num_invalid_spec_tokens,
request_id=req_id,
)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
pooler_output = pooler_outputs[req_index] if pooler_outputs else None
kv_transfer_params = None
status_before_stop = request.status
finish_reason = None
routed_experts = None
# Check for stop and update request status.
if new_token_ids:
new_token_ids, stopped = self._update_request_with_output(request, new_token_ids)
elif request.pooling_params and pooler_output is not None:
# Pooling stops as soon as there is output.
request.status = RequestStatus.FINISHED_STOPPED
stopped = True
# If criteria returns True, it means we must STOP the request.
# If criteria returns False, it might have triggered a background
# transfer (e.g. prefill finished / special token) but continues decoding.
if not stopped and self._process_kv_transfer_trigger(request, new_token_ids):
stopped = True
if stopped:
routed_experts = self._get_routed_experts(request)
# Capture finish_reason BEFORE _handle_stopped_request, which may
# reset the status to WAITING for streaming requests that continue.
finish_reason = request.get_finished_reason()
finished = self._handle_stopped_request(request)
if finished:
kv_transfer_params = self._free_request(request)
if status_before_stop == RequestStatus.RUNNING:
stopped_running_reqs.add(request)
else:
stopped_preempted_reqs.add(request)
# Extract sample logprobs if needed.
if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs:
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
if new_token_ids and self.structured_output_manager.should_advance(request):
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if not ok:
logger.warning(
"Unexpected: grammar rejected tokens %s for request %s.",
new_token_ids,
req_id,
)
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None or kv_transfer_params or stopped:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=finish_reason,
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
)
)
if self.omni_connector is not None:
custom_process_next_stage_input_func = self.custom_process_next_stage_input_func
put_chunk(self.omni_connector, pooler_output, request, custom_process_next_stage_input_func)
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
# Remove the stopped requests from the running and waiting queues.
if stopped_running_reqs:
self.running = remove_all(self.running, stopped_running_reqs)
if stopped_preempted_reqs:
# This is a rare case and unlikely to impact performance.
self.waiting.remove_requests(stopped_preempted_reqs)
# [Main] Handle failed KV load requests
if failed_kv_load_req_ids and not self.recompute_kv_load_failures:
requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
for request in requests:
outputs[request.client_index].append(
EngineCoreOutput(
request_id=request.request_id,
new_token_ids=[],
finish_reason=request.get_finished_reason(),
events=request.take_events(),
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
)
)
# [Omni] Cleanup state for finished requests
for req in stopped_running_reqs:
if req.request_id not in self.waiting_for_transfer_free:
if req.request_id in self.transfer_triggered_requests:
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
# Same for preempted
for req in stopped_preempted_reqs:
if req.request_id not in self.waiting_for_transfer_free:
if req.request_id in self.transfer_triggered_requests:
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()}
finished_req_ids = self.finished_req_ids_dict
if finished_req_ids:
# Include ids of requests that finished since last outputs
# were sent.
for client_index, finished_set in finished_req_ids.items():
# Set finished request set in EngineCoreOutputs for this client.
if (eco := engine_core_outputs.get(client_index)) is not None:
eco.finished_requests = finished_set
else:
engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
finished_req_ids.clear()
if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats)) is not None:
# Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
# We must return the stats even if there are no request
# outputs this step.
engine_core_outputs[0] = eco = EngineCoreOutputs()
eco.scheduler_stats = stats
# This is where we free blocks that were held for transfer
try:
kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
if kv_extracted_ids:
for req_id in kv_extracted_ids:
# Mark transfer as finished
if req_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req_id)
logger.debug(f"[Omni] KV Transfer finished for {req_id}")
if req_id in self.waiting_for_transfer_free:
# Now it's safe to free blocks
req = self.requests.get(req_id)
if req:
self.kv_cache_manager.free(req)
if req_id in self.requests:
del self.requests[req_id]
if req_id in self.transfer_triggered_requests:
self.transfer_triggered_requests.remove(req_id)
if req_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req_id)
logger.debug(f"Freed blocks for {req_id} after transfer extraction")
self.waiting_for_transfer_free.remove(req_id)
except Exception:
init_logger(__name__).exception("Failed to process finished transfer requests")
return engine_core_outputs
def _free_request(self, request: Request) -> dict[str, Any] | None:
# TODO(wzliu): in offline mode, we should not end the process until all data has been transferred
"""Mark a request as finished and free its resources."""
# 1. Standard cleanup parts from base _free_request
delay_free_blocks = False
kv_xfer_params = None
if self.connector is not None:
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
request_id = request.request_id
self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
# 2. Omni Specific: Check if we need to transfer KV
if self._should_transfer_kv_for_request(request_id):
already_triggered = request_id in self.transfer_triggered_requests
is_active = request_id in self.active_kv_transfers
if already_triggered:
if is_active:
# It triggered but hasn't finished yet. We MUST wait.
logger.debug(f"[Omni] Request {request_id} finished but transfer is still ACTIVE. Waiting.")
self.waiting_for_transfer_free.add(request_id)
# We do NOT mark for transfer again, just wait.
kv_xfer_params = None # No new transfer params
return kv_xfer_params
else:
logger.debug(
f"[Omni] Request {request_id} finished and transfer no longer ACTIVE (extracted/acked). "
"Freeing immediately."
)
else:
self.waiting_for_transfer_free.add(request_id)
self._mark_request_for_kv_transfer(request_id, request.num_computed_tokens)
# Return KV transfer metadata so it propagates to RequestOutput
if request_id in self.requests_needing_kv_transfer:
transfer_data = self.requests_needing_kv_transfer[request_id]
kv_xfer_params = {
"past_key_values": transfer_data["block_ids"],
"kv_metadata": {"seq_len": transfer_data["seq_len"], "block_ids": transfer_data["block_ids"]},
}
# Also update request.additional_information for good measure
add_info = getattr(request, "additional_information", None)
# If additional_information is an AdditionalInformationPayload-like object,
# unpack list_data into a plain dict.
if (
add_info is not None
and hasattr(add_info, "entries")
and isinstance(getattr(add_info, "entries"), dict)
):
request.additional_information = {
k: getattr(v, "list_data")
for k, v in getattr(add_info, "entries").items()
if getattr(v, "list_data", None) is not None
}
add_info = request.additional_information
if add_info is None:
request.additional_information = {}
add_info = request.additional_information
if isinstance(add_info, dict):
add_info.update(kv_xfer_params)
return kv_xfer_params
# 3. Standard Freeing
if not delay_free_blocks:
self._free_blocks(request)
return kv_xfer_params
def _free_blocks(self, request: Request):
# Helper to match base class structure if not directly available
# VLLMScheduler has _free_blocks
super()._free_blocks(request)
def _mark_request_for_kv_transfer(self, req_id: str, seq_len: int) -> None:
"""Mark a request as needing KV cache transfer when it finishes."""
# Avoid duplicate marking (if already pending in queue)
if req_id in self.requests_needing_kv_transfer:
return
if self._should_transfer_kv_for_request(req_id):
# [Omni] Get block IDs from KVCacheManager
try:
block_ids_tuple = self.kv_cache_manager.get_block_ids(req_id)
if block_ids_tuple and len(block_ids_tuple) > 0:
block_ids = block_ids_tuple[0]
# [Omni] Fix: Truncate blocks to match seq_len snapshot
# We need to know block_size. Usually in self.cache_config.block_size
# Note: vllm_config might not be directly available, check scheduler_config or cache_config
if hasattr(self, "cache_config") and hasattr(self.cache_config, "block_size"):
block_size = self.cache_config.block_size
elif hasattr(self, "scheduler_config") and hasattr(
self.scheduler_config, "block_size"
): # Some versions
block_size = self.scheduler_config.block_size
else:
raise ValueError("Block size not found in cache_config or scheduler_config")
# ceil(seq_len / block_size)
num_blocks = (seq_len + block_size - 1) // block_size
if len(block_ids) > num_blocks:
logger.debug(
f"[Omni] Truncating blocks for {req_id} from {len(block_ids)} "
f"to {num_blocks} (seq_len={seq_len})"
)
block_ids = block_ids[:num_blocks]
else:
block_ids = []
except Exception as e:
init_logger(__name__).warning(f"Failed to get block IDs for {req_id}: {e}")
block_ids = []
self.requests_needing_kv_transfer[req_id] = {"seq_len": seq_len, "block_ids": block_ids}
logger.debug(f"Marked request {req_id} for KV cache transfer (len={seq_len}, blocks={len(block_ids)})")
def _should_transfer_kv_for_request(self, req_id: str) -> bool:
"""Determine if a request should trigger KV cache transfer."""
need_send = False
# Try to read from vLLM Config (where YAML config is typically loaded)
# Check for omni_kv_config attribute
omni_kv_config = getattr(self.vllm_config.model_config, "omni_kv_config", None)
if omni_kv_config:
# omni_kv_config could be an object or a dict
if isinstance(omni_kv_config, dict):
need_send = omni_kv_config.get("need_send_cache", False)
else:
need_send = getattr(omni_kv_config, "need_send_cache", False)
return need_send
def has_requests(self) -> bool:
"""Check if there are any requests to process, including KV transfers."""
# [Omni] Also check for pending KV transfers
if self.requests_needing_kv_transfer or self.active_kv_transfers or self.waiting_for_transfer_free:
return True
return super().has_requests()
def has_finished_requests(self) -> bool:
"""Check if there are any finished requests (including those needing KV transfer)."""
if self.requests_needing_kv_transfer or self.active_kv_transfers or self.waiting_for_transfer_free:
return True
return super().has_finished_requests()
def has_unfinished_requests(self) -> bool:
"""Check if there are any unfinished requests (including those needing KV transfer)."""
# [Omni] Also check for pending KV transfers to ensure the engine loop continues
# MUST verify waiting_for_transfer_free and active_kv_transfers
# Otherwise engine loop might exit before transfer Ack is received.
if self.requests_needing_kv_transfer or self.active_kv_transfers or self.waiting_for_transfer_free:
return True
return super().has_unfinished_requests()
def get_finished_requests_needing_kv_transfer(self) -> dict[str, dict]:
"""Get and clear the list of requests needing KV cache transfer.
Returns dict: {req_id: {"seq_len": int, "block_ids": list[int]}}
"""
requests = self.requests_needing_kv_transfer.copy()
# Mark these requests as ACTIVE (sent to runner)
self.active_kv_transfers.update(requests.keys())
self.requests_needing_kv_transfer.clear()
return requests
import time
from collections import defaultdict
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.distributed.kv_events import KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.request_queue import create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.core.sched.utils import remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData
from vllm_omni.distributed.omni_connectors.adapter import get_chunk_for_generation
from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
from vllm_omni.outputs import OmniModelRunnerOutput
class OmniGenerationScheduler(VLLMScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
model_config = self.vllm_config.model_config
self.omni_connector = None
if model_config.async_chunk:
connector_config = model_config.stage_connector_config
connector_specs = ConnectorSpec(
name=connector_config.get("name", "SharedMemoryConnector"),
extra=connector_config.get("extra", {}),
)
self.omni_connector = OmniConnectorFactory.create_connector(connector_specs)
self.stage_id = getattr(self.vllm_config.model_config, "stage_id", None)
def schedule(self) -> SchedulerOutput:
"""Diffusion fast path:
- Feed all input tokens of the request at once
(if 0, allocate 1 placeholder token).
- If the token budget cannot be satisfied at once, fall back to the
default vLLM scheduling.
"""
token_budget = self.max_num_scheduled_tokens
scheduled_timestamp = time.monotonic()
scheduled_new_reqs: list[Request] = []
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
scheduled_running_reqs: list[Request] = []
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
scheduled_encoder_inputs: dict[str, list[int]] = {}
cached_prompt_token_ids: dict[str, list[int]] = {}
# Temporary queue: preserve waiting order, do not disturb non-diffusion requests
skipped_waiting_requests = create_request_queue(self.policy)
req_index = 0
# OMNI: Track requests that are already finished (e.g., marked by connector)
# These should be removed from running and not scheduled
already_finished_reqs: set[Request] = set()
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
if self.omni_connector is not None:
get_chunk_for_generation(self.omni_connector, request)
# OMNI: Skip requests that are already finished or not in self.requests
# This can happen when connector marks request as finished
if request.status == RequestStatus.FINISHED_STOPPED or request.request_id not in self.requests:
already_finished_reqs.add(request)
req_index += 1
continue
num_computed_tokens = request.num_computed_tokens
required_tokens = max(len(request.prompt_token_ids) - num_computed_tokens, 1)
num_new_tokens = min(required_tokens, token_budget)
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
if new_blocks is None:
# Allocation failed (e.g., VRAM pressure); stop the fast path for
# running requests here. The request stays in self.running and is
# retried on a later scheduling step.
break
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
cached_prompt_token_ids[request.request_id] = request.prompt_token_ids
token_budget -= num_new_tokens
scheduled_running_reqs.append(request)
req_index += 1
# OMNI: Remove already finished requests from running queue
if already_finished_reqs:
self.running = remove_all(self.running, already_finished_reqs)
# Fast path selection and scheduling (treat all as diffusion requests,
# independent of pooling_params)
while self.waiting and token_budget > 0 and len(self.running) < self.max_num_running_reqs:
request = self.waiting.peek_request()
if self.omni_connector is not None:
get_chunk_for_generation(self.omni_connector, request)
# OMNI: Skip requests that are already finished or not in self.requests
# This can happen when connector marks request as finished
if request.status == RequestStatus.FINISHED_STOPPED or request.request_id not in self.requests:
# Pop the finished request from waiting queue and don't schedule it
self.waiting.pop_request()
continue
# Uniformly treat as diffusion. A feature flag can be added later
# via config or request tag.
# Allocate all input tokens for the request in one shot
# (allocate 1 placeholder if zero)
required_tokens = max(len(request.prompt_token_ids), 1)
num_new_tokens = min(required_tokens, token_budget)
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
if new_blocks is None:
# Allocation failed (e.g., VRAM pressure); stop fast path and
# fall back to default scheduling
# Put the current request back to the head of the waiting queue
# Note: the original queue order is preserved
break
# Officially schedule this request
request = self.waiting.pop_request()
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED, scheduled_timestamp)
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
scheduled_new_reqs.append(request)
# Return skipped waiting requests
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# If fast path scheduled none, fall back to the original scheduling
if not num_scheduled_tokens:
return super().schedule()
# Compute common prefix blocks (aligned with v1)
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
# Assemble SchedulerOutput (align with v0.14.0)
if self.use_v2_model_runner:
# No resumed reqs in fast path; pass prefill_token_ids for new reqs.
new_reqs_data = [
OmniNewRequestData.from_request(
req,
req_to_new_blocks[req.request_id].get_block_ids(),
getattr(req, "_all_token_ids", None),
)
for req in scheduled_new_reqs
]
else:
new_reqs_data = [
OmniNewRequestData.from_request(req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
# No running/resumed reqs scheduled in our fast path
cached_reqs_data = self._make_cached_request_data(
running_reqs=scheduled_running_reqs,
resumed_reqs=[],
num_scheduled_tokens=num_scheduled_tokens,
spec_decode_tokens=scheduled_spec_decode_tokens,
req_to_new_blocks=req_to_new_blocks,
)
cached_reqs_data = OmniCachedRequestData(
req_ids=cached_reqs_data.req_ids,
resumed_req_ids=cached_reqs_data.resumed_req_ids,
new_token_ids=cached_reqs_data.new_token_ids,
all_token_ids=cached_reqs_data.all_token_ids,
new_block_ids=cached_reqs_data.new_block_ids,
num_computed_tokens=cached_reqs_data.num_computed_tokens,
num_output_tokens=cached_reqs_data.num_output_tokens,
prompt_token_ids=cached_prompt_token_ids,
)
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
preempted_req_ids=set(),
)
# Record the request ids scheduled in this step (v0.14.0 behavior).
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
# KVTransfer: package metadata
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
# EC Connector: package metadata
if self.ec_connector is not None:
ec_meta = self.ec_connector.build_connector_meta(scheduler_output)
scheduler_output.ec_connector_metadata = ec_meta
# Update internal state (advance num_computed_tokens, free encoder inputs,
# etc.)
self._update_after_schedule(scheduler_output)
try:
# Rewrap base NewRequestData entries with OmniNewRequestData,
# enriching with request-level payloads
new_list = []
for nr in scheduler_output.scheduled_new_reqs:
req_id = getattr(nr, "req_id", None)
request = self.requests.get(req_id) if req_id else None
# Build omni entry preserving all base fields
omni_nr = OmniNewRequestData(
req_id=nr.req_id,
external_req_id=(getattr(request, "external_req_id", None) if request else None),
prompt_token_ids=nr.prompt_token_ids,
mm_features=nr.mm_features,
sampling_params=nr.sampling_params,
pooling_params=nr.pooling_params,
block_ids=nr.block_ids,
num_computed_tokens=nr.num_computed_tokens,
lora_request=nr.lora_request,
# Enrich with omni payloads from the live request object
prompt_embeds=(getattr(request, "prompt_embeds", None) if request else None),
additional_information=(getattr(request, "additional_information", None) if request else None),
)
new_list.append(omni_nr)
scheduler_output.scheduled_new_reqs = new_list # type: ignore[assignment]
except Exception:
# If anything goes wrong, leave the original output unchanged
init_logger(__name__).exception("Failed to wrap scheduled_new_reqs with OmniNewRequestData")
return scheduler_output
"""
Scheduler for the diffusion model.
This scheduler is modified to stop the request immediately for the diffusion model.
This is because the diffusion model can generate the final image/audio in one step.
Note: This is just a minimal modification to the original scheduler,
and there should be some further efforts to optimize the scheduler.
The original scheduler is still used for the AR model.
"""
def update_from_output(
self,
scheduler_output: SchedulerOutput,
model_runner_output: OmniModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
"""Update the scheduler state based on the model runner output.
This method is modified to stop the request immediately for the diffusion model.
"""
sampled_token_ids = model_runner_output.sampled_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
cudagraph_stats: CUDAGraphStat | None = model_runner_output.cudagraph_stats
perf_stats: PerfStats | None = None
if self.perf_metrics and self.perf_metrics.is_enabled():
perf_stats = self.perf_metrics.get_step_perf_stats_per_gpu(scheduler_output)
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: KVConnectorStats | None = (
kv_connector_output.kv_connector_stats if kv_connector_output else None
)
# Merge connector-side stats (align with v0.14.0)
if kv_connector_stats and self.connector:
kv_stats = self.connector.get_kv_connector_stats()
if kv_stats:
kv_connector_stats = kv_connector_stats.aggregate(kv_stats)
failed_kv_load_req_ids = None
if kv_connector_output and getattr(kv_connector_output, "invalid_block_ids", None):
failed_kv_load_req_ids = self._handle_invalid_blocks(kv_connector_output.invalid_block_ids)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# to avoid expensive operations inside the loop.
stopped_running_reqs: set[Request] = set()
stopped_preempted_reqs: set[Request] = set()
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
assert num_tokens_scheduled > 0
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
# Skip requests that were recovered from KV load failure
continue
request = self.requests.get(req_id)
if request is None:
# The request is already finished. This can happen if the
# request is aborted while the model is executing it (e.g.,
# in pipeline parallelism).
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index] if sampled_token_ids else []
scheduled_spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
if scheduled_spec_token_ids:
num_draft_tokens = len(scheduled_spec_token_ids)
num_accepted = len(generated_token_ids) - 1
num_rejected = num_draft_tokens - num_accepted
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens.
if request.num_computed_tokens > 0:
request.num_computed_tokens -= num_rejected
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted,
)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
pooler_output = pooler_outputs[req_index] if pooler_outputs else None
status_before_stop = request.status
finish_reason = None
routed_experts = None
# Diffusion request: completes in one step; mark finished and free resources
if request.status == RequestStatus.FINISHED_STOPPED or (
self.omni_connector is None and request.num_computed_tokens >= request.num_prompt_tokens
):
request.status = RequestStatus.FINISHED_STOPPED
# Optional: set a stop_reason for front-end clarity
# (does not affect protocol)
request.stop_reason = request.stop_reason # or "generation_done"
stopped = True
if stopped:
routed_experts = self._get_routed_experts(request)
finish_reason = request.get_finished_reason()
finished = self._handle_stopped_request(request)
if finished:
kv_transfer_params = self._free_request(request)
if status_before_stop == RequestStatus.RUNNING:
stopped_running_reqs.add(request)
else:
stopped_preempted_reqs.add(request)
# Extract sample logprobs if needed.
if request.sampling_params is not None and request.sampling_params.logprobs is not None and logprobs:
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))
if new_token_ids and self.structured_output_manager.should_advance(request):
# NOTE: structured_output_request should not be None if
# use_structured_output, we have check above, so safe to ignore
# type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] # noqa: E501
req_id, new_token_ids
)
# spec_token_ids comes from the model runner output
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None or kv_transfer_params or stopped:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=finish_reason,
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
)
)
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
# Remove the stopped requests from the running and waiting queues.
if stopped_running_reqs:
self.running = remove_all(self.running, stopped_running_reqs)
if stopped_preempted_reqs:
# This is a rare case and unlikely to impact performance.
self.waiting.remove_requests(stopped_preempted_reqs)
# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
# Collect and publish KV cache events (align with v0.14.0)
events = self.kv_cache_manager.take_events()
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs = {client_index: EngineCoreOutputs(outputs=outs) for client_index, outs in outputs.items()}
finished_req_ids = self.finished_req_ids_dict
if finished_req_ids:
# Include ids of requests that finished since last outputs
# were sent.
for client_index, finished_set in finished_req_ids.items():
# Set finished request set in EngineCoreOutputs for this client.
if (eco := engine_core_outputs.get(client_index)) is not None:
eco.finished_requests = finished_set
else:
engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
finished_req_ids.clear()
if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats)) is not None:
# Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
# We must return the stats even if there are no request
# outputs this step.
engine_core_outputs[0] = eco = EngineCoreOutputs()
eco.scheduler_stats = stats
return engine_core_outputs
from dataclasses import dataclass, field
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.request import Request
from vllm_omni.engine import AdditionalInformationPayload, PromptEmbedsPayload
@dataclass
class OmniNewRequestData(NewRequestData):
"""New request data for omni models with embeddings support.
Extends NewRequestData to include prompt embeddings and additional
information for direct transfer between pipeline stages.
Args:
prompt_embeds: Optional serialized prompt embeddings payload
external_req_id: Optional external request ID for tracking
additional_information: Optional serialized additional information
dictionary containing tensors or lists
"""
# Optional serialized prompt embeddings
prompt_embeds: PromptEmbedsPayload | None = None
# Optional external request ID for tracking
external_req_id: str | None = None
# Optional serialized additional information
additional_information: AdditionalInformationPayload | None = None
@classmethod
def from_request(
cls,
request: Request,
block_ids: tuple[list[int], ...],
prefill_token_ids: list[int] | None = None,
) -> "OmniNewRequestData":
"""Create OmniNewRequestData from a Request object.
Args:
request: Request object to convert
block_ids: Tuple of block ID lists for KV cache allocation
prefill_token_ids: Optional full prefill token IDs, passed when the v2 model
runner is in use
Returns:
OmniNewRequestData instance with data from the request
"""
return cls(
req_id=request.request_id,
external_req_id=request.external_req_id,
prompt_token_ids=request.prompt_token_ids,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
prompt_embeds=request.prompt_embeds,
prefill_token_ids=prefill_token_ids,
additional_information=request.additional_information,
)
@dataclass
class OmniCachedRequestData(CachedRequestData):
"""Cached request data for omni models with embeddings support.
Args:
prompt_token_ids: Mapping from request ID to list of prompt token IDs
"""
prompt_token_ids: dict[str, list[int]]
@dataclass
class OmniSchedulerOutput(SchedulerOutput):
"""Scheduler output with omni-specific transfer metadata."""
finished_requests_needing_kv_transfer: dict[str, dict] = field(default_factory=dict)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar
import torch
from vllm_omni.platforms import current_omni_platform
class AttentionBackend(ABC):
"""Abstract class for diffusion attention backends."""
accept_output_buffer: bool = False
@classmethod
def supports_attention_mask(cls) -> bool:
return False
@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_impl_cls() -> type["AttentionImpl"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_supported_head_sizes() -> list[int]:
"""Get the list of supported head sizes for this backend."""
raise NotImplementedError
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
supported_head_sizes = cls.get_supported_head_sizes()
return (not supported_head_sizes) or head_size in supported_head_sizes
@dataclass
class AttentionMetadata:
attn_mask: torch.Tensor | None = None
joint_attn_mask: torch.Tensor | None = None
# attention mask for the joint query, key, and value; its layout depends on joint_strategy
joint_query: torch.Tensor | None = None
# a tensor replicated across processes and appended to the front or rear of query, depending on joint_strategy
joint_key: torch.Tensor | None = None
# a tensor replicated across processes and appended to the front or rear of key, depending on joint_strategy
joint_value: torch.Tensor | None = None
# a tensor replicated across processes and appended to the front or rear of value, depending on joint_strategy
joint_strategy: str = "front"
# the strategy for joining query, key, and value; either "front" or "rear"
T = TypeVar("T", bound=AttentionMetadata)
class AttentionImpl(ABC, Generic[T]):
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
softmax_scale: float,
causal: bool = False,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
raise NotImplementedError
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: T | None = None,
) -> torch.Tensor:
"""Dispatch to platform-specific forward implementation."""
if current_omni_platform.is_rocm():
return self.forward_hip(query, key, value, attn_metadata)
elif current_omni_platform.is_cuda():
return self.forward_cuda(query, key, value, attn_metadata)
elif current_omni_platform.is_npu():
return self.forward_npu(query, key, value, attn_metadata)
elif current_omni_platform.is_xpu():
return self.forward_xpu(query, key, value, attn_metadata)
else:
raise NotImplementedError(f"No forward implementation for platform: {current_omni_platform}")
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: T | None = None,
) -> torch.Tensor:
raise NotImplementedError
def forward_npu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: T | None = None,
) -> torch.Tensor:
raise NotImplementedError
def forward_xpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: T | None = None,
) -> torch.Tensor:
raise NotImplementedError
def forward_hip(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: T | None = None,
) -> torch.Tensor:
# By default, HIP ops are compatible with CUDA ops.
return self.forward_cuda(query, key, value, attn_metadata)
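# Illustrative only: a minimal, hypothetical backend pair built on torch's
# scaled_dot_product_attention, showing how get_impl_cls() and forward_cuda() plug into
# the abstract interface above. The class names are examples, the joint_* metadata
# fields are ignored, and any attn_mask is assumed to broadcast to
# (batch, heads, q_len, kv_len).
class _SDPAToyImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float,
        causal: bool = False,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        self.softmax_scale = softmax_scale
        self.causal = causal

    def forward_cuda(self, query, key, value, attn_metadata=None):
        # query/key/value: (batch, seq, heads, head_dim); SDPA expects heads first.
        mask = attn_metadata.attn_mask if attn_metadata is not None else None
        out = torch.nn.functional.scaled_dot_product_attention(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2),
            attn_mask=mask,
            is_causal=self.causal,
            scale=self.softmax_scale,
        )
        return out.transpose(1, 2)


class _SDPAToyBackend(AttentionBackend):
    accept_output_buffer: bool = False

    @classmethod
    def supports_attention_mask(cls) -> bool:
        return True

    @staticmethod
    def get_name() -> str:
        return "TOY_SDPA"

    @staticmethod
    def get_impl_cls() -> type["_SDPAToyImpl"]:
        return _SDPAToyImpl

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls():
        raise NotImplementedError

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return []  # empty list: supports_head_size() accepts any head size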
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
from vllm_omni.diffusion.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def supports_attention_mask(cls) -> bool:
return True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [64, 96, 128, 192, 256]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> type["FlashAttentionImpl"]:
return FlashAttentionImpl
class FlashAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
softmax_scale: float,
causal: bool = False,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.num_heads = num_heads
self.causal = causal
self.softmax_scale = softmax_scale
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
"""CUDA/ROCm flash attention implementation."""
# Import flash attention functions with fallback chain from utils/fa.py
# FA3 (fa3_fwd_interface) -> FA3 (flash_attn_interface) -> FA2 (flash_attn)
from vllm_omni.diffusion.attention.backends.utils.fa import (
HAS_FLASH_ATTN,
_pad_input,
_unpad_input,
_upad_input,
flash_attn_func,
flash_attn_varlen_func,
)
if not HAS_FLASH_ATTN:
raise ImportError(
"FlashAttentionBackend requires Flash Attention. "
"Please install one of: fa3-fwd, flash-attention, or flash-attn. "
"Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA"
)
query_length = query.size(1)
attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None
# Contains at least one padding token in the sequence
if attention_mask is not None and torch.any(~attention_mask):
assert attention_mask.ndim == 2, "attention_mask must be 2D, (batch_size, seq_len)"
q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
query, key, value, attention_mask, query_length, _unpad_input
)
out_unpad = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
causal=self.causal,
softmax_scale=self.softmax_scale,
)
if isinstance(out_unpad, tuple):
out_unpad = out_unpad[0]
out = _pad_input(out_unpad, indices_q, query.size(0), query_length)
else:
out = flash_attn_func(
query,
key,
value,
causal=self.causal,
softmax_scale=self.softmax_scale,
)
# FA3 may return (out, lse) tuple, FA2 returns just out
if isinstance(out, tuple):
out = out[0]
return out
def forward_npu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
"""NPU attention implementation using mindiesd."""
try:
from mindiesd import attention_forward
except ImportError:
raise ImportError(
"FlashAttentionBackend NPU implementation requires MindIE-SD. "
"Please install MindIE-SD to enable NPU attention support. "
"For installation details, see https://gitcode.com/Ascend/MindIE-SD"
"Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA"
)
attention_mask = attn_metadata.attn_mask if attn_metadata else None
output = attention_forward(
query,
key,
value,
attn_mask=attention_mask,
opt_mode="manual",
op_type="fused_attn_score",
layout="BNSD",
)
return output
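# Hedged usage sketch: assumes a CUDA device and an installed flash-attn/FA3 package;
# shapes, dtype, and the function name are illustrative only.
def _example_flash_attention_usage() -> torch.Tensor:
    impl = FlashAttentionImpl(num_heads=8, head_size=64, softmax_scale=64**-0.5, causal=False)
    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    # No attn_metadata: the unpadded flash_attn_func path is taken.
    out = impl.forward_cuda(q, k, v)
    assert out.shape == q.shape
    return out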
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Diffusion attention backend registry.
This module provides an enum-based registry for diffusion attention backends,
similar to vLLM's AttentionBackendEnum. Each backend registers its class path,
and platforms can override or extend backends using register_backend().
"""
from collections.abc import Callable
from enum import Enum, EnumMeta
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
from vllm_omni.diffusion.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
class _DiffusionBackendEnumMeta(EnumMeta):
"""Metaclass for DiffusionAttentionBackendEnum to provide better error messages."""
def __getitem__(cls, name: str) -> "DiffusionAttentionBackendEnum":
"""Get backend by name with helpful error messages."""
try:
return super().__getitem__(name) # type: ignore[return-value]
except KeyError:
members = list(cls.__members__.keys())
valid_backends = ", ".join(members)
raise ValueError(
f"Unknown diffusion attention backend: '{name}'. Valid options are: {valid_backends}"
) from None
class DiffusionAttentionBackendEnum(Enum, metaclass=_DiffusionBackendEnumMeta):
"""Enumeration of all supported diffusion attention backends.
The enum value is the default class path, but this can be overridden
at runtime using register_backend().
To get the actual backend class (respecting overrides), use:
backend.get_class()
Example:
# Get backend class
backend = DiffusionAttentionBackendEnum.FLASH_ATTN
backend_cls = backend.get_class()
# Register custom backend
@register_diffusion_backend(DiffusionAttentionBackendEnum.CUSTOM)
class MyCustomBackend:
...
"""
# Common backends (available on most platforms)
FLASH_ATTN = "vllm_omni.diffusion.attention.backends.flash_attn.FlashAttentionBackend"
TORCH_SDPA = "vllm_omni.diffusion.attention.backends.sdpa.SDPABackend"
SAGE_ATTN = "vllm_omni.diffusion.attention.backends.sage_attn.SageAttentionBackend"
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides).
Returns:
The fully qualified class path string
Raises:
ValueError: If backend has empty path and is not registered
"""
path = _DIFFUSION_ATTN_OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_diffusion_backend(DiffusionAttentionBackendEnum.{self.name}, "
f"'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides).
Returns:
The backend class
Raises:
ImportError: If the backend class cannot be imported
ValueError: If backend has empty path and is not registered
"""
return resolve_obj_by_qualname(self.get_path())
def is_overridden(self) -> bool:
"""Check if this backend has been overridden.
Returns:
True if the backend has a registered override
"""
return self in _DIFFUSION_ATTN_OVERRIDES
def clear_override(self) -> None:
"""Clear any override for this backend, reverting to the default."""
_DIFFUSION_ATTN_OVERRIDES.pop(self, None)
# Override registry
_DIFFUSION_ATTN_OVERRIDES: dict[DiffusionAttentionBackendEnum, str] = {}
def register_diffusion_backend(
backend: DiffusionAttentionBackendEnum,
class_path: str | None = None,
) -> Callable[[type], type]:
"""Register or override a diffusion backend implementation.
Args:
backend: The DiffusionAttentionBackendEnum member to register
class_path: Optional class path. If not provided and used as
decorator, will be auto-generated from the class.
Returns:
Decorator function if class_path is None, otherwise a no-op
Examples:
# Override an existing backend
@register_diffusion_backend(DiffusionAttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
...
# Override another backend by name (e.g., a platform-specific ASCEND_ATTN member)
@register_diffusion_backend(DiffusionAttentionBackendEnum.ASCEND_ATTN)
class CustomAscendAttentionBackend:
...
# Direct registration
register_diffusion_backend(
DiffusionAttentionBackendEnum.CUSTOM,
"my.module.MyCustomBackend"
)
"""
def decorator(cls: type) -> type:
_DIFFUSION_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
return cls
if class_path is not None:
_DIFFUSION_ATTN_OVERRIDES[backend] = class_path
return lambda x: x
return decorator
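# Illustrative only (assumes the vllm_omni package and its backend modules are importable):
# resolving a built-in backend and temporarily overriding it. MyFlashAttnBackend is a
# hypothetical class used purely for demonstration.
def _example_registry_usage() -> None:
    # Resolve the default class for a built-in backend (respects any override).
    flash_cls = DiffusionAttentionBackendEnum["FLASH_ATTN"].get_class()

    # Override it with a custom implementation, then revert.
    @register_diffusion_backend(DiffusionAttentionBackendEnum.FLASH_ATTN)
    class MyFlashAttnBackend(flash_cls):
        pass

    assert DiffusionAttentionBackendEnum.FLASH_ATTN.is_overridden()
    DiffusionAttentionBackendEnum.FLASH_ATTN.clear_override()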
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention
# test if flash_attn (FA2) is available
try:
import flash_attn # noqa: F401
from flash_attn.flash_attn_interface import _flash_attn_forward # noqa: F401
HAS_FLASH_ATTN = True
except (ImportError, ModuleNotFoundError):
HAS_FLASH_ATTN = False
# FA3 detection: try multiple sources (forward only, no backward needed for inference)
# Source 1: flash_attn_interface (from flash-attention source build)
# Source 2: fa3_fwd_interface (from fa3-fwd PyPI package, supports Ampere/Ada/Hopper)
# Note: FA3 high-level API may or may not return softmax_lse depending on version.
# For Ring Attention which requires LSE, we fall back to low-level API if needed.
HAS_FA3 = False
fa3_fwd_func = None # Low-level forward function (_flash_attn_forward)
fa3_attn_func = None # High-level attention function (flash_attn_func)
# Try flash_attn_interface first (from flash-attention source build)
try:
from flash_attn_interface import _flash_attn_forward as fa3_fwd_func # noqa: F401
from flash_attn_interface import flash_attn_func as fa3_attn_func # noqa: F401
HAS_FA3 = True
except (ImportError, ModuleNotFoundError):
pass
# Fallback: try fa3_fwd_interface (PyPI package, supports Ampere/Ada/Hopper)
if not HAS_FA3:
try:
from fa3_fwd_interface import _flash_attn_forward as fa3_fwd_func # noqa: F401
from fa3_fwd_interface import flash_attn_func as fa3_attn_func # noqa: F401
HAS_FA3 = True
except (ImportError, ModuleNotFoundError):
pass
# Legacy aliases for backward compatibility
HAS_FLASH_ATTN_HOPPER = HAS_FA3
flash_attn_forward_hopper = fa3_fwd_func
flash3_attn_func = fa3_attn_func
try:
from flashinfer.prefill import single_prefill_with_kv_cache # noqa: F401
HAS_FLASHINFER = True
except (ImportError, ModuleNotFoundError):
HAS_FLASHINFER = False
try:
import aiter # noqa: F401
from aiter import flash_attn_func as flash_attn_func_aiter # noqa: F401
HAS_AITER = True
except (ImportError, ModuleNotFoundError):
HAS_AITER = False
try:
import sageattention # noqa: F401
HAS_SAGE_ATTENTION = True
except (ImportError, ModuleNotFoundError):
HAS_SAGE_ATTENTION = False
try:
import spas_sage_attn # noqa: F401
HAS_SPARSE_SAGE_ATTENTION = True
except (ImportError, ModuleNotFoundError):
HAS_SPARSE_SAGE_ATTENTION = False
try:
import torch_npu # noqa: F401
HAS_NPU = True
except (ImportError, ModuleNotFoundError):
HAS_NPU = False
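# Hedged sketch of how these flags are typically consumed when picking a kernel at call
# time; the helper name is illustrative and the order mirrors the comments above
# (FA3 first, then FA2, otherwise fall back to an SDPA-based path).
def _pick_flash_forward():
    if HAS_FA3:
        return fa3_attn_func
    if HAS_FLASH_ATTN:
        from flash_attn import flash_attn_func
        return flash_attn_func
    return None  # caller should use a torch SDPA path instead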
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention
import math
import torch
from .ring_globals import (
HAS_AITER,
HAS_FA3,
HAS_FLASH_ATTN,
HAS_FLASHINFER,
fa3_fwd_func,
)
_scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention
_scaled_dot_product_efficient_attention = torch.ops.aten._scaled_dot_product_efficient_attention
try:
import torch_musa # noqa: F401
_scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_attention_flash_musa
_scaled_dot_product_efficient_attention = None
except ModuleNotFoundError:
pass
if HAS_AITER:
from aiter import flash_attn_func as flash_attn_func_aiter
if HAS_FLASH_ATTN:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
if HAS_FLASHINFER:
from flashinfer.prefill import single_prefill_with_kv_cache
_LOG2_E = math.log2(math.e)
def pytorch_attn_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p=0.0,
softmax_scale=None,
causal=True,
window_size=(-1, -1),
softcap=None,
alibi_slopes=None,
return_softmax=False,
op_type="efficient",
):
"""
q shape (bs, seqlen, nhead, hs)
k shape (bs, seqlen, nhead, hs)
v shape (bs, seqlen, nhead, hs)
"""
assert op_type in ["flash", "efficient"], f"Invalid op_type: {op_type}"
# Fallback logic: Flash Attention does not support float32.
# If op_type is 'flash' but dtype is float32, force 'efficient'.
if op_type == "flash" and q.dtype == torch.float32:
op_type = "efficient"
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if op_type == "flash":
out, lse = _scaled_dot_product_flash_attention(
q,
k,
v,
dropout_p=dropout_p,
is_causal=causal,
scale=softmax_scale,
)[:2]
elif op_type == "efficient":
out, lse = _scaled_dot_product_efficient_attention(
q,
k,
v,
attn_bias=None,
compute_log_sumexp=True,
dropout_p=dropout_p,
is_causal=causal,
scale=softmax_scale,
)[:2]
else:
raise ValueError(f"Invalid op_type: {op_type}")
out = out.transpose(1, 2)
lse = lse.to(q.dtype)
return out, lse
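# Hedged usage sketch for the pure-PyTorch path (assumes a CUDA device; shapes and the
# helper name are illustrative). Passing float32 with op_type="flash" exercises the
# fallback to the "efficient" kernel implemented above.
def _example_pytorch_attn() -> None:
    q = torch.randn(1, 64, 8, 64, device="cuda")  # (batch, seqlen, nhead, head_size)
    k = torch.randn(1, 64, 8, 64, device="cuda")
    v = torch.randn(1, 64, 8, 64, device="cuda")
    out, lse = pytorch_attn_forward(q, k, v, causal=False, op_type="flash")
    assert out.shape == q.shape  # lse is returned heads-first and may be padded along seqlen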
def flash_attn_forward(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=None,
alibi_slopes=None,
return_softmax=False,
):
assert HAS_FLASH_ATTN, "FlashAttention is not available"
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if flash_attn.__version__ < "2.6.3":
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
else:
block_out, block_lse, _, _ = _flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
return block_out, block_lse
def fa3_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax):
"""FA3 forward pass for inference.
FA3 supports Ampere, Ada, and Hopper GPUs. Dropout is ignored since FA3 is inference-only.
Uses low-level API (_flash_attn_forward) which always returns softmax_lse,
required for Ring Attention's correct accumulation.
"""
assert HAS_FA3, "FA3 is not available"
assert fa3_fwd_func is not None, "FA3 low-level API (fa3_fwd_func) not available"
# Low-level API always returns (out, softmax_lse, S_dmask, rng_state)
out, softmax_lse, *_ = fa3_fwd_func(
q,
k,
v,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0] if window_size else -1,
window_size_right=window_size[1] if window_size else -1,
softcap=softcap if softcap else 0.0,
)
return out, softmax_lse
# Legacy alias for backward compatibility
flash_attn3_func_forward = fa3_forward
def flash_attn_forward_aiter(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=None,
alibi_slopes=None,
return_softmax=False,
):
assert HAS_AITER, "Aiter is not available"
block_out, block_lse = flash_attn_func_aiter(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_lse=True,
)
return block_out, block_lse
def flashinfer_attn_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
softmax_scale: float | None = None,
causal: bool = False,
window_size: tuple[int, int] = (-1, -1),
softcap: float | None = None,
alibi_slopes: torch.Tensor | None = None,
return_softmax: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert HAS_FLASHINFER, "FlashInfer is not available"
if q.ndim == 4:
if q.shape[0] > 1:
raise ValueError("batch size > 1 is not supported")
out, lse = single_prefill_with_kv_cache(
q[0],
k[0],
v[0],
sm_scale=softmax_scale,
causal=causal,
logits_soft_cap=softcap,
window_left=window_size[0],
return_lse=True,
)
lse = lse.transpose(0, 1)
out, lse = out.unsqueeze(0), lse.unsqueeze(0)
elif q.ndim == 3:
out, lse = single_prefill_with_kv_cache(
q,
k,
v,
sm_scale=softmax_scale,
causal=causal,
logits_soft_cap=softcap,
window_left=window_size[0],
return_lse=True,
)
lse = lse.transpose(0, 1)
else:
raise ValueError(f"Invalid input shape: {q.shape}")
lse = lse / _LOG2_E
return out, lse
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention
from collections.abc import Callable
from enum import Enum
from functools import partial
import torch
from .ring_globals import (
HAS_SAGE_ATTENTION,
HAS_SPARSE_SAGE_ATTENTION,
)
from .ring_kernels import (
flash_attn3_func_forward,
flash_attn_forward,
flash_attn_forward_aiter,
flashinfer_attn_forward,
pytorch_attn_forward,
)
if HAS_SAGE_ATTENTION:
import sageattention
if HAS_SPARSE_SAGE_ATTENTION:
from spas_sage_attn.autotune import SparseAttentionMeansim
class AttnType(Enum):
AITER = "aiter"
FA = "fa"
FA3 = "fa3"
FLASHINFER = "flashinfer"
TORCH = "torch"
SAGE_AUTO = "sage_auto"
SAGE_FP16 = "sage_fp16"
SAGE_FP16_TRITON = "sage_fp16_triton"
SAGE_FP8 = "sage_fp8"
SAGE_FP8_SM90 = "sage_fp8_sm90"
SPARSE_SAGE = "sparse_sage"
@classmethod
def from_string(cls, s: str):
for member in cls:
if member.value == s:
return member
raise ValueError(f"'{s}' is not a valid {cls.__name__}")
def select_flash_attn_impl(
impl_type: AttnType,
stage: str = "fwd-only",
attn_processor: torch.nn.Module | None = None,
) -> Callable[..., tuple[torch.Tensor, torch.Tensor | None]]:
"""Select attention implementation for forward pass (inference only).
Args:
impl_type: The attention implementation type.
stage: Must be "fwd-only" (backward not supported for inference).
attn_processor: Optional custom attention processor.
Returns:
Callable[..., tuple[torch.Tensor, torch.Tensor | None]]: The attention
forward function for the specified implementation.
"""
if stage != "fwd-only":
raise ValueError(f"Only 'fwd-only' stage is supported for inference. Got: {stage}")
if impl_type == AttnType.AITER:
return flash_attn_forward_aiter
elif impl_type == AttnType.FA:
return flash_attn_forward
elif impl_type == AttnType.FA3:
return flash_attn3_func_forward
elif impl_type == AttnType.FLASHINFER:
return flashinfer_attn_forward
elif impl_type == AttnType.TORCH:
return pytorch_attn_forward
elif impl_type == AttnType.SAGE_AUTO:
if not HAS_SAGE_ATTENTION:
raise ImportError("SageAttention is not available!")
return partial(
sageattention.sageattn,
tensor_layout="NHD",
return_lse=True,
)
elif impl_type == AttnType.SAGE_FP16:
if not HAS_SAGE_ATTENTION:
raise ImportError("SageAttention is not available!")
return partial(
sageattention.sageattn_qk_int8_pv_fp16_cuda,
pv_accum_dtype="fp32",
tensor_layout="NHD",
return_lse=True,
)
elif impl_type == AttnType.SAGE_FP16_TRITON:
if not HAS_SAGE_ATTENTION:
raise ImportError("SageAttention is not available!")
return partial(
sageattention.sageattn_qk_int8_pv_fp16_triton,
tensor_layout="NHD",
return_lse=True,
)
elif impl_type == AttnType.SAGE_FP8:
if not HAS_SAGE_ATTENTION:
raise ImportError("SageAttention is not available!")
return partial(
sageattention.sageattn_qk_int8_pv_fp8_cuda,
pv_accum_dtype="fp32+fp32",
tensor_layout="NHD",
return_lse=True,
)
elif impl_type == AttnType.SAGE_FP8_SM90:
if not HAS_SAGE_ATTENTION:
raise ImportError("SageAttention is not available!")
return partial(
sageattention.sageattn_qk_int8_pv_fp8_cuda_sm90,
pv_accum_dtype="fp32+fp32",
tensor_layout="NHD",
return_lse=True,
)
elif impl_type == AttnType.SPARSE_SAGE:
if not HAS_SPARSE_SAGE_ATTENTION:
raise ImportError("SparseSageAttention is not available!")
if not isinstance(attn_processor, SparseAttentionMeansim):
raise ImportError("SparseSageAttention is only available with a SparseAttentionProcessor class passed in")
def fn(q, k, v, causal=False, softmax_scale=None, *args, **kwargs):
return (
attn_processor(
q,
k,
v,
is_causal=causal,
scale=softmax_scale,
tensor_layout="NHD",
),
None,
)
return fn
elif attn_processor is not None:
return attn_processor
else:
raise ValueError(f"Unknown flash attention implementation: {impl_type}")
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention
import torch
import torch.nn.functional as F
__all__ = ["update_out_and_lse", "flatten_varlen_lse", "unflatten_varlen_lse"]
# torch.jit.script is intentionally omitted here to keep debugging and flexible shape handling easy
def _update_out_and_lse(
out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
block_out = block_out.to(torch.float32)
B, S, H, D = out.shape
# --- Shape Correction Logic for block_lse ---
# Goal: block_lse should be (B, S, H, 1) to match out (B, S, H, D)
# Debug info
# print(f"DEBUG _update: out={out.shape}, block_lse={block_lse.shape}")
# Case 0: If block_lse is already 4D, check if it matches
if block_lse.dim() == 4:
if block_lse.shape[1] == S and block_lse.shape[2] == H:
pass # Good
elif block_lse.shape[1] == H and block_lse.shape[2] == S:
block_lse = block_lse.transpose(1, 2)
elif block_lse.shape[1] == H and block_lse.shape[2] >= S: # Padding case
block_lse = block_lse[:, :, :S, :].transpose(1, 2)
# If shape is (B, H, S, 1) but expected (B, S, H, 1) because out is (B, S, H, D)
elif block_lse.shape[1] == H and block_lse.shape[2] == S and block_lse.shape[3] == 1:
block_lse = block_lse.transpose(1, 2)
# Case 1: block_lse is 3D (B, H, S) or (B, S, H) or (B, ?, ?)
elif block_lse.dim() == 3:
# Check for (B, H, S) - Standard SDPA/FA output
if block_lse.shape[1] == H and block_lse.shape[2] == S:
block_lse = block_lse.transpose(1, 2).unsqueeze(-1)
# Check for (B, S, H)
elif block_lse.shape[1] == S and block_lse.shape[2] == H:
block_lse = block_lse.unsqueeze(-1)
# Check for Padding: (B, H, S_pad) where S_pad >= S
elif block_lse.shape[1] == H and block_lse.shape[2] >= S:
# print(f"DEBUG: Trimming padding from lse. {block_lse.shape} -> S={S}")
block_lse = block_lse[:, :, :S].transpose(1, 2).unsqueeze(-1)
# Check for weird case: (B, S, H_pad) ? Unlikely for LSE but possible
elif block_lse.shape[1] == S and block_lse.shape[2] >= H:
block_lse = block_lse[:, :, :H].unsqueeze(-1)
# Check for flipped weird case: (B, S_pad, H)
elif block_lse.shape[1] >= S and block_lse.shape[2] == H:
block_lse = block_lse[:, :S, :].unsqueeze(-1)
# --- Shape Correction for lse (internal state) ---
# Ensure lse matches block_lse's corrected shape (B, S, H, 1)
if lse.shape != block_lse.shape:
# If lse was initialized with wrong shape, try to fix it
if lse.dim() == 4 and lse.shape[1] == block_lse.shape[2] and lse.shape[2] == block_lse.shape[1]:
lse = lse.transpose(1, 2)
elif lse.shape[1] >= S: # slice if lse was initialized with padding
lse = lse[:, :S, :, :]
# If the shapes still disagree at this point, the update below relies on broadcasting.
try:
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
except RuntimeError as e:
# Surface the offending shapes before re-raising to make layout mismatches easier to debug.
print(f"ERROR in _update_out_and_lse: {e}")
print(f"out: {out.shape}, lse: {lse.shape}")
print(f"block_out: {block_out.shape}, block_lse: {block_lse.shape}")
raise
return out, lse
def update_out_and_lse(
out: torch.Tensor | None,
lse: torch.Tensor | None,
block_out: torch.Tensor,
block_lse: torch.Tensor,
slice_=None,
) -> tuple[torch.Tensor, torch.Tensor]:
if out is None:
if slice_ is not None:
raise RuntimeError("first update_out_and_lse should not pass slice_ args")
out = block_out.to(torch.float32)
# Initialize LSE with robust logic (same as _update)
B, D1, D2, D3 = out.shape
S_guess = D1
H_guess = D2
if block_lse.dim() == 3:
if block_lse.shape[1] == H_guess and block_lse.shape[2] == S_guess:
lse = block_lse.transpose(1, 2).unsqueeze(-1)
elif block_lse.shape[1] == S_guess and block_lse.shape[2] == H_guess:
lse = block_lse.unsqueeze(-1)
elif block_lse.shape[1] == H_guess and block_lse.shape[2] >= S_guess: # Padding
lse = block_lse[:, :, :S_guess].transpose(1, 2).unsqueeze(-1)
elif block_lse.shape[1] == S_guess and block_lse.shape[2] >= H_guess: # Padding/Weird
lse = block_lse[:, :, :H_guess].unsqueeze(-1)
elif block_lse.shape[1] >= S_guess and block_lse.shape[2] == H_guess:
lse = block_lse[:, :S_guess, :].unsqueeze(-1)
# Reverse case: What if out is (B, H, S, D) so S=D2, H=D1?
elif block_lse.shape[1] == D1 and block_lse.shape[2] >= D2: # Matches (H, S)
# Then out is (B, H, S, D). We should transpose out!
out = out.transpose(1, 2)
lse = block_lse[:, :, :D2].transpose(1, 2).unsqueeze(-1) # (B, S, H, 1)
else:
# Fallback
lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
else:
# Case 0: If block_lse is already 4D, check if it matches
if block_lse.dim() == 4:
if block_lse.shape[1] == S_guess and block_lse.shape[2] == H_guess:
lse = block_lse
elif block_lse.shape[1] == H_guess and block_lse.shape[2] == S_guess:
lse = block_lse.transpose(1, 2)
elif block_lse.shape[1] == H_guess and block_lse.shape[2] >= S_guess: # Padding case
lse = block_lse[:, :, :S_guess, :].transpose(1, 2)
elif block_lse.shape[1] == D1 and block_lse.shape[2] >= D2: # Matches (H, S)
# Then out is (B, H, S, D). We should transpose out!
out = out.transpose(1, 2)
lse = block_lse[:, :, :D2].transpose(1, 2) # (B, S, H, 1)
else:
lse = block_lse
else:
lse = block_lse
elif slice_ is not None:
slice_out, slice_lse = out[slice_], lse[slice_]
slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
out[slice_], lse[slice_] = slice_out, slice_lse
else:
out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
return out, lse
def flatten_varlen_lse(lse, cu_seqlens):
new_lse = []
for i in range(len(cu_seqlens) - 1):
start, end = cu_seqlens[i], cu_seqlens[i + 1]
new_lse.append(lse[i, :, : end - start])
return torch.cat(new_lse, dim=1)
def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
num_seq = len(cu_seqlens) - 1
num_head = lse.shape[-2]
new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device)
for i in range(num_seq):
start, end = cu_seqlens[i], cu_seqlens[i + 1]
new_lse[i, : end - start] = lse[start:end]
return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous()
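# A CPU-only sanity sketch of the accumulation formula (the helpers and shapes below are
# illustrative): attention over two KV blocks merged with update_out_and_lse should match
# attention over the full KV.
def _naive_block_attn(q, k, v):
    # q: (B, S, H, D), k/v: (B, T, H, D) -> out (B, S, H, D), lse (B, H, S)
    scores = torch.einsum("bshd,bthd->bhst", q, k) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)
    out = torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)
    return out, lse


def _check_two_block_merge() -> None:
    torch.manual_seed(0)
    q = torch.randn(1, 8, 2, 16)
    k = torch.randn(1, 16, 2, 16)
    v = torch.randn(1, 16, 2, 16)
    out, lse = None, None
    for k_blk, v_blk in zip(k.chunk(2, dim=1), v.chunk(2, dim=1)):
        block_out, block_lse = _naive_block_attn(q, k_blk, v_blk)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)
    ref_out, _ = _naive_block_attn(q, k, v)
    assert torch.allclose(out, ref_out, atol=1e-5)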