Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
requires_files = glob.glob('requirements/*.txt')
requires_files += ["pyproject.toml"]
for file in requires_files:
print(f">>> cleaning {file}")
with open(file) as f:
lines = f.readlines()
if "torch" in "".join(lines).lower():
print("removed:")
with open(file, 'w') as f:
for line in lines:
if 'torch' not in line.lower():
f.write(line)
else:
print(line.strip())
print(f"<<< done cleaning {file}")
print()
print("vLLM is now using 'uv' to disable build isolation for 'torch'.")
print("Please instead install vLLM with 'uv pip install -e .' (must use 'uv')")
......@@ -14,6 +14,8 @@ import typing
import vllm.env_override # noqa: F401
MODULE_ATTRS = {
"bc_linter_skip": "._bc_linter:bc_linter_skip",
"bc_linter_include": "._bc_linter:bc_linter_include",
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
......@@ -54,6 +56,8 @@ if typing.TYPE_CHECKING:
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from ._bc_linter import bc_linter_include, bc_linter_skip
else:
def __getattr__(name: str) -> typing.Any:
......@@ -70,6 +74,8 @@ else:
__all__ = [
"__version__",
"bc_linter_skip",
"bc_linter_include",
"__version_tuple__",
"LLM",
"ModelRegistry",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vllm/_bc_linter.py
from __future__ import annotations
from typing import Any, Callable, TypeVar, overload
T = TypeVar("T")
@overload
def bc_linter_skip(obj: T) -> T:
...
@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]:
...
def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
"""
No-op decorator to mark symbols/files for BC-linter suppression.
Usage:
@bc_linter_skip
def legacy_api(...): ...
"""
def _wrap(x: T) -> T:
return x
return _wrap if obj is None else obj
@overload
def bc_linter_include(obj: T) -> T:
...
@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]:
...
def bc_linter_include(obj: Any = None, *, reason: str | None = None):
"""
Usage:
@bc_linter_include
def public_api(...): ...
"""
def _wrap(x: T) -> T:
return x
return _wrap if obj is None else obj
__all__ = ["bc_linter_skip", "bc_linter_include"]
......@@ -280,6 +280,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
def poly_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
bias: torch.Tensor, epsilon: float) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
input_contiguous = input.contiguous()
torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon)
def apply_repetition_penalties_torch(
logits: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
......@@ -715,6 +722,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability)
def cutlass_sparse_compress(a: torch.Tensor) \
-> tuple[torch.Tensor, torch.Tensor]:
"""
......@@ -1630,20 +1638,6 @@ def concat_and_cache_mla(
scale)
def cp_fused_concat_and_cache_mla(
kv_c: torch.Tensor,
k_pe: torch.Tensor,
cp_local_token_select_indices: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
kv_cache_dtype, scale)
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
......@@ -1852,13 +1846,13 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
return out
def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor,
def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
q_nope: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor, page_table: torch.Tensor,
workspace: torch.Tensor, scale: float,
num_kv_splits: int) -> torch.Tensor:
torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe,
torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe,
kv_c_and_k_pe_cache, seq_lens,
page_table, workspace, scale,
num_kv_splits)
......@@ -1933,6 +1927,35 @@ class CPUDNNLGEMMHandler:
torch.ops._C.release_dnnl_matmul_handler(self.handler)
if hasattr(torch.ops._C, "create_onednn_mm_handler"):
_supports_onednn = True
else:
_supports_onednn = False
def create_onednn_mm(
weight: torch.Tensor, # [K, N]
primitive_cache_size: int = 128,
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_mm_handler(
weight, primitive_cache_size)
return handler
def onednn_mm(
dnnl_handler: CPUDNNLGEMMHandler,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
torch.ops._C.onednn_mm(output, x.reshape(-1, dnnl_handler.k), bias,
dnnl_handler.handler)
return output
def create_onednn_scaled_mm(
weight: torch.Tensor, # [K, N]
weight_scales: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -241,10 +242,9 @@ class ipex_ops:
k_scale_float: float = 1.0,
v_scale_float: float = 1.0,
) -> None:
assert kv_cache_dtype == "auto"
# TODO: support FP8 kv cache.
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slot_mapping)
key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale_float, v_scale_float)
@staticmethod
def flash_attn_varlen_func(
......@@ -349,3 +349,56 @@ class ipex_ops:
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
@staticmethod
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function is designed for both static and dynamic quantization:
If you provide the scale, it will use static scaling and if you omit
it, the scale will be determined dynamically. Currently, XPU platform
only supports dynamic quantization. The function also allows optional
padding of the output tensors for downstream kernels that will benefit
from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
assert scale is None, "only dynamic fp8 quantization supported on XPU"
assert not use_per_token_if_dynamic, (
"per token dynamic fp8 quantization not supported on XPU")
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)
return output, scale
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import torch
......@@ -11,17 +12,29 @@ from .base import get_vllm_public_assets
VLM_IMAGES_DIR = "vision_model_images"
ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato"]
ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato",
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
"Grayscale_8bits_palette_sample_image",
"1280px-Venn_diagram_rgb", "RGBA_comp", "237-400x300",
"231-200x300", "27-500x500", "17-150x600",
"handelsblatt-preview", "paper-11"]
@dataclass(frozen=True)
class ImageAsset:
name: ImageAssetName
def get_path(self, ext: str) -> Path:
"""
Return s3 path for given image.
"""
return get_vllm_public_assets(filename=f"{self.name}.{ext}",
s3_prefix=VLM_IMAGES_DIR)
@property
def pil_image(self) -> Image.Image:
image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
s3_prefix=VLM_IMAGES_DIR)
def pil_image(self, ext="jpg") -> Image.Image:
image_path = self.get_path(ext)
return Image.open(image_path)
@property
......@@ -29,6 +42,9 @@ class ImageAsset:
"""
Image embeddings, only used for testing purposes with llava 1.5.
"""
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
s3_prefix=VLM_IMAGES_DIR)
image_path = self.get_path('pt')
return torch.load(image_path, map_location="cpu", weights_only=True)
def read_bytes(self, ext: str) -> bytes:
p = Path(self.get_path(ext))
return p.read_bytes()
......@@ -110,22 +110,23 @@ class VideoAsset:
def filename(self) -> str:
return self._NAME_TO_FILE[self.name]
@property
def video_path(self) -> str:
return download_video_asset(self.filename)
@property
def pil_images(self) -> list[Image.Image]:
video_path = download_video_asset(self.filename)
ret = video_to_pil_images_list(video_path, self.num_frames)
ret = video_to_pil_images_list(self.video_path, self.num_frames)
return ret
@property
def np_ndarrays(self) -> npt.NDArray:
video_path = download_video_asset(self.filename)
ret = video_to_ndarrays(video_path, self.num_frames)
ret = video_to_ndarrays(self.video_path, self.num_frames)
return ret
@property
def metadata(self) -> dict[str, Any]:
video_path = download_video_asset(self.filename)
ret = video_get_metadata(video_path)
ret = video_get_metadata(self.video_path)
return ret
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
......@@ -134,5 +135,4 @@ class VideoAsset:
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path = download_video_asset(self.filename)
return librosa.load(video_path, sr=sampling_rate)[0]
return librosa.load(self.video_path, sr=sampling_rate)[0]
......@@ -257,6 +257,32 @@ class AttentionLayer(Protocol):
class AttentionImpl(ABC, Generic[T]):
# Whether the attention impl can return the softmax lse for decode.
# Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode: bool = False
# some attention backends might not always want to return lse
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False
dcp_world_size: int
dcp_rank: int
def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
try:
from vllm.distributed.parallel_state import get_dcp_group
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
and self.can_return_lse_for_decode
return self
@abstractmethod
def __init__(
self,
......
......@@ -734,6 +734,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
)
assert prefill_output.shape == output[:
num_prefill_tokens].shape
......@@ -755,6 +756,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
).squeeze(1)
except Exception as e:
logger.error("Error in PagedAttention.forward_decode: %s",
......@@ -787,6 +789,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
).squeeze(1)
return output
......
......@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.platforms.cuda import CudaPlatform
class FlashMLABackend(MLACommonBackend):
......@@ -181,6 +182,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert is_flashmla_supported(), \
"FlashMLA is not supported on this device"
# disallow FlashMLA on NVIDIA Blackwell (SM 10.0+) GPUs
# context:
# https://github.com/deepseek-ai/FlashMLA/issues/83
# https://github.com/vllm-project/vllm/issues/24513
if CudaPlatform.has_device_capability(100):
raise NotImplementedError(
"FlashMLA is temporarily disabled on Blackwell (SM 10.0). "
"Please use CUTLASS_MLA or TRITON_MLA instead. "
"Example: `export VLLM_ATTENTION_BACKEND=CUTLASS_MLA`")
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
......
......@@ -824,7 +824,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
and context_lens_tensor is not None \
and context_lens_tensor[:self.num_prefills].max() > 0:
# NOTE: it is recommend you read the `Chunked Prefill` section in
# NOTE: it is recommended you read the `Chunked Prefill` section in
# the comment at the top of the file before trying to understand
# the following code
......@@ -1056,7 +1056,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return layer.weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
if self.use_llama_nn and self.kv_b_proj.quant_method is None:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
......
......@@ -360,13 +360,13 @@ class MultiHeadAttention(nn.Module):
# currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
_Backend.FLEX_ATTENTION):
backend = _Backend.XFORMERS
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
} else _Backend.TORCH_SDPA
_Backend.TORCH_SDPA,
_Backend.TORCH_SDPA_VLLM_V1,
_Backend.XFORMERS,
_Backend.PALLAS_VLLM_V1,
_Backend.ROCM_AITER_FA,
} else current_platform.get_vit_attn_backend()
if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
......@@ -399,7 +399,8 @@ class MultiHeadAttention(nn.Module):
key,
value,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
elif (self.attn_backend == _Backend.TORCH_SDPA
or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1):
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
out = F.scaled_dot_product_attention(query,
......@@ -413,6 +414,19 @@ class MultiHeadAttention(nn.Module):
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
# ROCm Flash Attention expects (batch, seq, heads, head_dim)
out = flash_attn_varlen_func(query,
key,
value,
softmax_scale=self.scale)
else:
# ViT attention hasn't supported this backend yet
raise NotImplementedError(
f"ViT attention hasn't supported {self.attn_backend} "
f"backend yet.")
return out.reshape(bsz, q_len, -1)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import numpy as np
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
from vllm.v1.kv_cache_interface import CrossAttentionSpec
logger = init_logger(__name__)
def _get_max_encoder_len(vllm_config: VllmConfig) -> int:
return MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(
vllm_config.model_config)
def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
kv_cache_spec: CrossAttentionSpec,
device: torch.device) -> torch.Tensor:
"""Get cross-attention slot mappings."""
block_size = kv_cache_spec.block_size
slot_mappings = []
# Find indices with non-zero encoder sequence lengths
# The majority of parallel requests will be running the
# decoder, so this list should be relatively small.
active_indices = np.nonzero(encoder_seq_lens)[0]
for req_index in active_indices:
encoder_seq_len = encoder_seq_lens[req_index].item()
# Calculate the number of blocks needed for this request
num_blocks_needed = cdiv(encoder_seq_len, block_size)
# Get the block IDs for this request from the tensor
req_block_ids = block_table_tensor[req_index]
# Get only the blocks we need (first num_blocks_needed blocks)
needed_block_ids = req_block_ids[:num_blocks_needed]
# All needed blocks are allocated
i_values = torch.arange(encoder_seq_len,
dtype=torch.int64,
device=device)
block_indices = i_values // block_size
block_offsets = i_values % block_size
block_numbers = needed_block_ids[block_indices]
slot_mapping = block_numbers * block_size + block_offsets
slot_mappings.append(slot_mapping)
if slot_mappings:
return torch.cat(slot_mappings)
else:
return torch.empty(0, dtype=torch.int64, device=device)
@functools.lru_cache
def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class CrossAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
max_encoder_len = _get_max_encoder_len(self.vllm_config)
new_metadata.max_seq_len = max_encoder_len
new_metadata.seq_lens = torch.full(
(new_metadata.num_reqs, ),
max_encoder_len,
dtype=torch.int32,
device=self.device,
)
new_metadata.seq_lens_cpu = torch.full(
(new_metadata.num_reqs, ),
max_encoder_len,
dtype=torch.int32,
device="cpu",
)
new_metadata.slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens, new_metadata.block_table_tensor,
self.kv_cache_spec, self.device)
return super().build(common_prefix_len, new_metadata, fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=CrossAttentionBuilder)
return attn_backend
class CrossAttention(Attention):
"""
Cross-attention for encoder-decoder models.
Handles attention between decoder queries and encoder keys/values.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
cache_config: Optional[CacheConfig] = None,
attn_type: Optional[str] = None,
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_cross_attention_backend(
underlying_attn_backend)
else:
# in v0 cross attention is handled inside the backends
attn_backend = None
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
"CrossAttention only supports AttentionType.ENCODER_DECODER")
super().__init__(num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_DECODER,
**kwargs)
......@@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit
def cdiv_fn(x, y):
......@@ -34,6 +36,7 @@ def kernel_paged_attention_2d(
scale, # float32
k_scale, # float32
v_scale, # float32
out_scale_inv,
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
......@@ -60,7 +63,9 @@ def kernel_paged_attention_2d(
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool
):
USE_FP8: tl.constexpr,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max):
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
......@@ -204,6 +209,9 @@ def kernel_paged_attention_2d(
# epilogue
acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (cur_batch_in_all_start_index * output_stride_0 +
query_head_idx * output_stride_1)
......@@ -234,6 +242,7 @@ def chunked_prefill_paged_decode(
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
output_scale=None,
# Optional tensor for sinks
sinks=None,
):
......@@ -266,6 +275,7 @@ def chunked_prefill_paged_decode(
sliding_window=sliding_window,
sm_scale=sm_scale,
skip_decode=True,
fp8_out_scale=output_scale,
sinks=sinks,
)
......@@ -316,7 +326,7 @@ def chunked_prefill_paged_decode(
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions,
head_size),
dtype=output.dtype,
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
......@@ -345,6 +355,7 @@ def chunked_prefill_paged_decode(
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
fp8_out_scale=output_scale,
)
else:
kernel_paged_attention_2d[(
......@@ -362,6 +373,8 @@ def chunked_prefill_paged_decode(
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
out_scale_inv=1.0 /
output_scale if output_scale is not None else 1.0,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
......@@ -388,4 +401,5 @@ def chunked_prefill_paged_decode(
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None,
USE_FP8=output_scale is not None,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton
@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
vlse_ptr, outputs_stride_B, outputs_stride_H,
outputs_stride_D, lses_stride_N, lses_stride_B,
lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
N_ROUNDED: tl.constexpr):
"""
Apply the all-gathered lses to correct each local rank's attention
output. we still need perform a cross-rank reduction to obtain the
final attention output.
Args:
output: [ B, H, D ]
lses : [ N, B, H ]
cp, batch, q_heads, v_head_dim
Return:
output: [ B, H, D ]
lse : [ B, H ]
"""
batch_idx = tl.program_id(axis=0).to(tl.int64)
head_idx = tl.program_id(axis=1).to(tl.int64)
d_offsets = tl.arange(0, HEAD_DIM)
num_n_offsets = tl.arange(0, N_ROUNDED)
# shape = [N]
lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
lses_stride_B + head_idx * lses_stride_H
# calc final lse
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
lse_max = tl.max(lse, axis=0)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log(lse_acc)
lse += lse_max
lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
tl.store(vlse_ptr + lse_offsets, lse)
# shape = [D]
output_offsets = batch_idx * outputs_stride_B + \
head_idx * outputs_stride_H + \
d_offsets * outputs_stride_D
# correct output
lse_offset = lse_idx * lses_stride_N + batch_idx * \
lses_stride_B + head_idx * lses_stride_H
lse_tmp = tl.load(lses_ptr + lse_offset)
lse_finally = lse_tmp - lse
lse_finally = tl.where(
(lse_finally != lse_finally) | (lse_finally == float('inf')),
-float('inf'), lse_finally)
factor = tl.exp(lse_finally)
output = tl.load(outputs_ptr + output_offsets)
output = output * factor
tl.store(new_output_ptr + output_offsets, output)
class CPTritonContext:
""" The CPTritonContext is used to avoid recompilation of the Triton JIT.
"""
def __init__(self):
self.inner_kernel = None
def call_kernel(self, kernel, grid, *regular_args, **const_args):
if self.inner_kernel is None:
self.inner_kernel = kernel[grid](*regular_args, **const_args)
else:
self.inner_kernel[grid](*regular_args)
def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
ctx: CPTritonContext):
"""
Apply the all-gathered lses to correct each local rank's attention
output. we still need perform a cross-rank reduction to obtain the
final attention output.
Args:
output: [ B, H, D ]
lses : [ N, B, H ]
Return:
output: [ B, H, D ]
lse : [ B, H ]
"""
if ctx is None:
ctx = CPTritonContext()
lse = torch.empty_like(lses[0])
grid = (out.shape[0], out.shape[1], 1)
regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
cp_rank)
const_args = {
"HEAD_DIM": out.shape[-1],
"N_ROUNDED": lses.shape[0],
}
ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
**const_args)
return out, lse
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
if cp_group.world_size == 1:
return cp_attn_out
if ctx is None:
ctx = CPTritonContext()
lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
dtype=cp_attn_lse.dtype,
device=cp_attn_lse.device)
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
out = cp_group.reduce_scatter(out, dim=1)
return out
......@@ -105,7 +105,9 @@ def flash_mla_with_kvcache(
descale_q,
descale_k,
)
return out, softmax_lse
# Note(hc): need revisit when we support DCP with decode query_len > 1.
return out.squeeze(1), softmax_lse.squeeze(-1)
#
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim
from vllm.utils import cdiv
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
"""
Load block tables from HBM into SRAM
`block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension.
"""
B_P_SIZE = 128
# reshape as `(num_tiles, num_blocks_per_tile)`
assert len(block_tables_hbm.shape) == 1
(num_total_blocks, ) = block_tables_hbm.shape
assert num_blocks_per_tile * num_tiles == num_total_blocks
block_tables_hbm = block_tables_hbm.reshape(
(num_tiles, num_blocks_per_tile))
block_tables_sbuf = nl.zeros(
(cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
dtype=nl.int32,
)
for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(num_blocks_per_tile)[None, :]
block_tables_sbuf[i, i_p, i_f] = nl.load(
block_tables_hbm[i_p + i * B_P_SIZE, i_f],
dtype=nl.int32,
mask=(i_p + i * B_P_SIZE < num_tiles),
)
return block_tables_sbuf
@nki.jit
def transform_block_tables_for_indirect_load(
block_tables,
block_size_tiling_factor,
num_head,
head_id,
):
"""
This function does two things:
1. calculate new `block_tables` for a `head_id` after flattening
`num_block`, `num_head`, and `block_size_tiling_factor` dimensions
2. transpose the result so that `block_table` for each tile is mapped to
SBUF Partition dimension for vectorized DMA
Tiling trick to further improve DMA performance:
Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
blocks of a given `head_id` from HBM, the load `cache[block_tables,
head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
fully utilize hardware parallelization. The solution is to tile `block_size`
into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
`(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.
Note:
We don't further tile D dimension as small DMA size also hurts performance.
"""
B_P_SIZE = 128
num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
block_tables.shape)
assert num_tiles_per_partition == B_P_SIZE
assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
block_tables_transposed = nl.ndarray(
(
num_loads,
par_dim(B_P_SIZE),
num_partitions * num_tiles_per_partition,
),
dtype=nl.int32,
)
# prepare iota ahead of time to avoid repeatedly using Gpsimd
if num_head > 1:
head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
head_id = nl.transpose(
head_id.broadcast_to((1, num_tiles_per_partition)))
if num_blocks_per_tile > 1:
head_id = head_id.broadcast_to(
(num_tiles_per_partition, num_blocks_per_tile))
if block_size_tiling_factor > 1:
broadcast_shape = (
num_tiles_per_partition,
num_blocks_per_tile,
block_size_tiling_factor,
)
offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
dtype=nl.int32).broadcast_to(broadcast_shape)
for partition_id in nl.affine_range(num_partitions):
block_tables_partition = block_tables[partition_id]
if num_head > 1:
# fuse num_block and num_head dimension
block_tables_partition = block_tables_partition * num_head + head_id
if block_size_tiling_factor > 1:
# need to apply block size tiling trick
assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
block_tables_partition = ((block_tables_partition *
block_size_tiling_factor).reshape(
(num_tiles_per_partition,
num_blocks_per_tile,
1)).broadcast_to(broadcast_shape))
new_block_tables = block_tables_partition + offset
new_block_tables = new_block_tables.reshape(
(num_tiles_per_partition, B_P_SIZE))
else:
new_block_tables = block_tables_partition
# transpose the block table so that it can be used by vector DGE
for i in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = (partition_id * num_tiles_per_partition +
nl.arange(num_tiles_per_partition)[None, :])
block_tables_transposed[i, i_p, i_f] = nl.transpose(
new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
return block_tables_transposed
@nki.jit
def load_kv_tile_from_cache(
cur_k_tile,
cur_v_tile,
kv_cache,
block_tables,
large_k_tile_idx,
num_blocks_per_large_tile,
tiled_block_size,
B_P_SIZE,
B_D_SIZE,
):
"""
Load KV cache and transform Key and Value into layout required by Matmul
Vectorized DMA Load layout:
Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
Layout used by attention matmuls:
Key: (par_dim(B_D_SIZE), seqlen_kv)
Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
"""
# load key cache
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_k_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
# Transpose SBUF tensor using PE
for tb_i in nl.affine_range(tiled_block_size):
cur_k_tile[
:,
nl.ds(
load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
B_P_SIZE,
),
] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])
# load value cache
for load_idx in nl.affine_range(num_loads):
loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
cur_v_tile[
:,
nl.ds(
load_idx * tiled_block_size * B_D_SIZE,
tiled_block_size * B_D_SIZE,
),
] = loaded
@nki.jit
def transpose_p_local(p_local_transposed,
p_local,
LARGE_TILE_SZ,
B_F_SIZE=512):
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
buffer=nl.sbuf,
dtype=p_local.dtype)
else:
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
buffer=nl.psum,
dtype=np.float32)
for j in nl.affine_range(B_F_SIZE // 128):
j_128_slice = nl.ds(j * 128, 128)
i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
p_local[:, i_j_128_slice])
else:
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local[:, i_j_128_slice])
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
p_local_t_tmp, dtype=p_local_transposed.dtype)
@nki.jit
def _flash_attention_core(
q_local_tile,
k,
v,
o_buffer,
l_buffer,
m_buffer,
kernel_dtype,
acc_type,
tile_mask,
use_causal_mask,
q_tile_idx=None,
initialize=False,
LARGE_TILE_SZ=2048,
B_P_SIZE=128,
B_F_SIZE=512,
B_D_SIZE=128,
qk_res_buffer=None,
):
"""
The flash attention core function to calculate self attention between a tile
of q and a block of K and V.
The q_local_tile has (B_P_SIZE, B_D_SIZE)
The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will
be split into size B_F_SIZE tiles
The results are stored in the following three buffers
o_buffer: (B_P_SIZE, d)
l_buffer: (B_P_SIZE, 1)
m_buffer: (B_P_SIZE, 1)
All IO buffers are in SBUF.
"""
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
buffer=nl.sbuf,
dtype=acc_type)
max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
dtype=acc_type)
for k_i in nl.affine_range(num_k_tile_per_large_tile):
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
if use_causal_mask:
# mask are used to only apply computation to the lower half of the
# matrix, which reduce the arithmetic intensity by up to 50%
multiplication_required_selection = (q_tile_idx * B_P_SIZE
>= k_i * B_F_SIZE)
else:
multiplication_required_selection = True
if multiplication_required_selection:
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True) # (p(128), 512)
qk_res_buf[:, k_i_b_f_slice] = nl.where(
tile_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
else:
qk_res_buf[:, k_i_b_f_slice] = -9984.0
# Calculate max of the current tile
max_local[:, k_i] = nisa.tensor_reduce(
np.max,
qk_res_buf[:, k_i_b_f_slice],
axis=(1, ),
dtype=acc_type,
negate=False,
)
if qk_res_buffer is not None:
qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])
max_ = nisa.tensor_reduce(
np.max,
max_local[:, :],
axis=(1, ),
dtype=acc_type,
negate=False,
)
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
dtype=o_buffer.dtype)
if initialize:
m_buffer[:, 0] = nl.copy(max_)
m_current = max_
else:
m_previous = nl.copy(m_buffer[:, 0])
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
m_current = m_buffer[:, 0]
# Compute scaling factor
alpha = nisa.activation(
np.exp,
m_previous,
bias=-1 * m_current,
scale=1.0,
)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
p_partial_sum = nl.ndarray(
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
dtype=acc_type,
)
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
# compute exp(qk - max)
# Compute partial row - tile sum of exp(qk - max))
# FIXME : Use activation accumulate to accumulate over k_r_i loop ?
p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
np.exp,
qk_res_buf[:, k_r_i_reduce_slice],
bias=-1 * m_current,
scale=1.0,
reduce_op=nl.add,
reduce_res=p_partial_sum[:, k_r_i],
dtype=kernel_dtype,
)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
transpose_p_local(
p_local_transposed=p_local_transposed,
p_local=p_local,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_F_SIZE=B_F_SIZE,
)
pv_psum = nl.zeros(
(par_dim(B_P_SIZE), B_D_SIZE),
dtype=np.float32,
buffer=nl.psum,
)
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
pv_psum[:, :] += nl.matmul(
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
transpose_x=True,
) # (128, 128) (p(Br), d)
if initialize:
o_buffer[:, :] = nl.copy(pv_psum[:, :])
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
else:
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
l_prev = l_buffer[:, 0]
l_exp = nl.add(
nl.exp(nl.subtract(l_prev, m_current)),
ps,
)
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
B_P_SIZE = 128
B_D_SIZE = v_hbm_tile.shape[-1]
loaded = nl.load(v_hbm_tile[
nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
:,
])
if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded
@nki.jit
def flash_paged_attention(
query,
key,
value,
kv_cache,
block_tables,
mask,
softmax_scale=None,
mixed_precision=True,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
):
"""
Flash PagedAttention Forward Kernel.
IO tensor layouts:
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d)
- This kernel requires seq_k == seq_v
- We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence
dimension.
- We use paged cache blocks (kv_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
block_tables (int32) and mask (int32)
- If mixed_precision is True, then all Tensor Engine operation will be
performed in bfloat16 and accumulation will be performed in float32.
Otherwise the intermediates will be in the same type as the inputs.
Compile-time Constants:
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
is set to `true`, if false, we use same precision as input types
- LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
computation reduction
GQA support Notes:
the spmd kernel for launching kernel should be on kv_heads instead of
nheads
Example usage:
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
usage: `flash_fwd[b, h](q, k, v, ...)`
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
"""
B_F_SIZE = 512
B_P_SIZE = 128
b, h, d, seqlen_q = query.shape
B_D_SIZE = d
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
_, num_blocks, k_h, block_size, _ = kv_cache.shape
q_h_per_k_h = h // k_h
assert b == 1, f"invalid batch size {b=}"
assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
cache_shape = (2, num_blocks, k_h, block_size, d)
assert (tuple(kv_cache.shape) == cache_shape
), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
assert key is None or tuple(key.shape) == (
1,
k_h,
d,
seqlen_q,
), f"key shape {key.shape} mismatch!"
assert value is None or tuple(value.shape) == (
1,
k_h,
seqlen_q,
d,
), f"value shape {value.shape} mismatch!"
assert (
nl.program_ndim() == 2
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
batch_id = nl.program_id(axis=0)
head_id = nl.program_id(axis=1)
(num_active_blocks, ) = block_tables.shape
context_kv_len = num_active_blocks * block_size
assert (
LARGE_TILE_SZ % B_F_SIZE == 0
), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
assert (context_kv_len % LARGE_TILE_SZ == 0
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert is_power_of_2(
num_blocks_per_large_tile
), f"{num_blocks_per_large_tile=} is expected of be power of 2"
if seqlen_q > B_F_SIZE:
MAX_REDUCTION_TILE = 2048
if seqlen_q // 2 > MAX_REDUCTION_TILE:
assert (
seqlen_q % MAX_REDUCTION_TILE == 0
), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
else:
assert (seqlen_q % B_F_SIZE == 0
), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
softmax_scale = softmax_scale or (1.0 / (d**0.5))
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
o = nl.ndarray((b, h, seqlen_q, d),
dtype=query.dtype,
buffer=nl.shared_hbm)
hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
None,
None,
None,
None,
)
if return_debug_tensors:
hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
qk_res_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
block_tables_sbuf = load_block_tables(
block_tables_hbm=block_tables,
num_tiles=num_large_k_tile,
num_blocks_per_tile=num_blocks_per_large_tile,
)
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
if num_blocks_per_large_tile < B_P_SIZE:
# we checked num_blocks_per_tile is a power of 2
assert B_P_SIZE % num_blocks_per_large_tile == 0
block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
# We assume block_size >= block_size_tiling_factor
assert block_size % block_size_tiling_factor == 0
else:
block_size_tiling_factor = 1
tiled_block_size = block_size // block_size_tiling_factor
# Indirect DMA load must be placed along Partition Dimension
block_tables_sbuf = transform_block_tables_for_indirect_load(
block_tables_sbuf,
block_size_tiling_factor=block_size_tiling_factor,
num_head=k_h,
head_id=head_id,
)
# Flatten KV cache to be 3D for loading into SBUF
new_cache_shape = (
2,
num_blocks * k_h * block_size_tiling_factor,
tiled_block_size * d,
)
kv_cache = kv_cache.reshape(new_cache_shape)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
l_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
m_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
cur_k_tile = nl.ndarray(
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype,
)
cur_v_tile = nl.ndarray(
(par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
dtype=kernel_dtype,
)
load_kv_tile_from_cache(
cur_k_tile=cur_k_tile,
cur_v_tile=cur_v_tile,
kv_cache=kv_cache,
block_tables=block_tables_sbuf,
large_k_tile_idx=large_k_tile_idx,
num_blocks_per_large_tile=num_blocks_per_large_tile,
tiled_block_size=tiled_block_size,
B_P_SIZE=B_P_SIZE,
B_D_SIZE=B_D_SIZE,
)
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(q_hbm_tile[:,
nl.ds(i *
B_P_SIZE, B_P_SIZE)])
if q_sbuf_tile.dtype != kernel_dtype:
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
kernel_dtype=kernel_dtype,
acc_type=acc_type,
tile_mask=cur_mask,
use_causal_mask=False,
q_tile_idx=i,
initialize=large_k_tile_idx == 0,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
)
# compute attention between input query, key and value
if key is not None and value is not None:
B_F_SIZE = min(seqlen_q, B_F_SIZE)
LARGE_TILE_SZ = seqlen_q
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
cur_v_tile = nl.ndarray(
(par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
dtype=kernel_dtype,
)
loaded = nl.load(key[batch_id, head_id, :, :])
if loaded.dtype != kernel_dtype:
loaded = nl.copy(loaded, dtype=kernel_dtype)
cur_k_tile[:, :] = loaded
v_hbm_tile = value[batch_id, head_id]
for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
load_v_tile(
v_hbm_tile=v_hbm_tile,
cur_v_tile=cur_v_tile,
large_tile_idx=0,
v_i=v_i,
LARGE_TILE_SZ=LARGE_TILE_SZ,
)
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(context_kv_len, LARGE_TILE_SZ),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(q_hbm_tile[:,
nl.ds(i *
B_P_SIZE, B_P_SIZE)])
if q_sbuf_tile.dtype != kernel_dtype:
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
kernel_dtype=kernel_dtype,
acc_type=acc_type,
tile_mask=cur_mask,
use_causal_mask=True,
q_tile_idx=i,
initialize=False,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
qk_res_buffer=(qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None),
)
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
out = nl.multiply(
o_buffer[i, i_q_h],
nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
dtype=kernel_dtype,
)
nl.store(
o[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
:,
],
out,
)
# maximum and summation statistics
if return_debug_tensors:
nl.store(
hbm_m_buffer[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
m_buffer[i, i_q_h, :, :],
)
nl.store(
hbm_l_buffer[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
l_buffer[i, i_q_h],
)
nl.store(
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
qk_res_buffer[batch_id, i_q_h, :, :],
)
if return_debug_tensors:
return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
return o
def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
"""
Reorder the mask to make it compatible with the flash attention kernel.
We vectorize KV cache read to improve DMA utilization. However, the layout
that maximizes DMA bandwidth changes the order tokens are consumed.
The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE,
tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And
each step the engine consumes a column (rather than a row) of B_P_SIZE
tokens. Therefore, the tokens are visited in a strided way.
To make sure mask matches the order tokens are consumed, we need to properly
transpose mask.
"""
total_query_len, total_seq_len = mask.shape
context_kv_len = total_seq_len - total_query_len
B_P_SIZE = 128
assert (LARGE_TILE_SZ
>= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"
num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
if tiled_block_size > 1:
# Mask reordering is needed when tiled_block_size > 1
device = mask.device
mask = mask.cpu()
context_mask = mask[:, :context_kv_len]
context_mask = context_mask.view(
total_query_len,
context_kv_len // LARGE_TILE_SZ,
num_tiled_blocks // B_P_SIZE,
B_P_SIZE,
tiled_block_size,
)
context_mask = context_mask.transpose(3, 4).reshape(
total_query_len, context_kv_len)
new_mask = mask[:, context_kv_len:]
return torch.concat([context_mask, new_mask], dim=1).to(device)
else:
return mask
def flash_attn_varlen_nkifunc(
query,
key,
value,
kv_cache,
block_table,
attn_mask,
n_kv_head=None,
head_size=None,
LARGE_TILE_SZ=2048,
mixed_precision=True,
):
"""
Compute flash paged attention for variable length sequences.
This function is a wrapper around the flash attention NKI kernel. It takes
in the following arguments:
- query: (1, n_heads, d, seq_q)
- key: (1, n_kv_heads, d, seq_k)
- value: (1, n_kv_heads, seq_v, d)
- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
- block_tables: (n_active_blocks, )
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
Notes:
- attn_mask must be reordered outside using `reorder_context_mask`
- Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
for better DMA throughput
"""
if n_kv_head is None:
n_kv_head = kv_cache.shape[2]
assert kv_cache.shape[0] == 2
assert kv_cache.shape[2] == n_kv_head
if head_size is None:
head_size = kv_cache.shape[-1]
kwargs = dict(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
block_tables=block_table,
mask=attn_mask,
softmax_scale=1.0 / (head_size**0.5),
mixed_precision=mixed_precision,
LARGE_TILE_SZ=LARGE_TILE_SZ,
)
o = flash_paged_attention[1, n_kv_head](**kwargs)
return o
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""
Writes key-value pairs to the KV cache at specified positions.
Args:
key (torch.Tensor): Key tensor with shape
(num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
kv_cache (torch.Tensor): Key/value cache tensor with shape
(2, num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens)
Returns:
None: Updates the kv_cache tensor in-place
"""
block_size = kv_cache.size(3)
n_kv_head = key.size(1)
# Calculate indices with explicit floor division
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size
# Create the head indices tensor
head_indices = torch.arange(n_kv_head, device=key.device)
# Update caches using index_put_
kv_cache.index_put_(
(torch.tensor([0], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), key)
kv_cache.index_put_(
(torch.tensor([1], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), value)
......@@ -6,9 +6,14 @@ from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd
......
......@@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)
float8_info = torch.finfo(current_platform.fp8_dtype())
# Here's an example autotuner config for this kernel. This config does provide
......@@ -43,6 +44,7 @@ def _fwd_kernel(Q,
sm_scale,
k_scale,
v_scale,
out_scale_inv,
B_Start_Loc,
B_Seqlen,
x: tl.constexpr,
......@@ -82,8 +84,11 @@ def _fwd_kernel(Q,
num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
USE_SINKS: tl.constexpr,
USE_FP8: tl.constexpr,
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0):
MAX_CTX_LEN: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -146,7 +151,7 @@ def _fwd_kernel(Q,
start_n = tl.multiple_of(start_n, BLOCK_SIZE)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
(start_n // BLOCK_SIZE) * stride_b_loc_s)
(start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64)
# [D,BLOCK_SIZE]
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
......@@ -284,6 +289,9 @@ def _fwd_kernel(Q,
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
......@@ -367,7 +375,7 @@ def _fwd_kernel_flash_attn_v2(
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
other=0).to(tl.int64)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
......@@ -575,7 +583,7 @@ def _fwd_kernel_alibi(
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
other=0).to(tl.int64)
off_k = (
bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
......@@ -743,6 +751,7 @@ def context_attention_fwd(q,
sliding_window=None,
sm_scale=None,
skip_decode=False,
fp8_out_scale=None,
sinks=None):
q_dtype_is_f32 = q.dtype is torch.float32
......@@ -793,6 +802,7 @@ def context_attention_fwd(q,
if alibi_slopes is not None:
assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi"
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
......@@ -870,6 +880,7 @@ def context_attention_fwd(q,
sm_scale,
k_scale,
v_scale,
1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
b_start_loc,
b_seq_len,
k_cache.shape[4],
......@@ -905,6 +916,7 @@ def context_attention_fwd(q,
BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
USE_FP8=fp8_out_scale is not None,
BLOCK_M=128,
BLOCK_N=64,
num_unroll_cache=4,
......
......@@ -10,9 +10,11 @@
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit
......@@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
@triton.jit
def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32
k_scale, # float32
v_scale, # float32
out_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
qq_bias_stride_0: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_QQ_BIAS: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
......@@ -281,6 +287,9 @@ def kernel_unified_attention_2d(
# epilogue
acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (query_offset_0[:, None] * output_stride_0 +
query_offset_1[:, None] * output_stride_1 +
......@@ -552,23 +561,27 @@ def kernel_unified_attention_3d(
@triton.jit
def reduce_segments(
output_ptr, # [num_tokens, num_query_heads, head_size]
segm_output_ptr,
#[num_tokens, num_query_heads, max_num_segments, head_size]
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr, # [num_seqs]
num_seqs, # int
num_query_heads: tl.constexpr, # int
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
block_table_stride: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
output_ptr, # [num_tokens, num_query_heads, head_size]
segm_output_ptr,
#[num_tokens, num_query_heads, max_num_segments, head_size]
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr, # [num_seqs]
num_seqs, # int
num_query_heads: tl.constexpr, # int
out_scale_inv, # float32
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
block_table_stride: tl.int64, # int
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
query_token_idx = tl.program_id(0)
query_head_idx = tl.program_id(1)
......@@ -624,6 +637,10 @@ def reduce_segments(
# safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
# write result
output_offset = (query_token_idx * output_stride_0 +
query_head_idx * output_stride_1 +
......@@ -649,6 +666,7 @@ def unified_attention(
k_descale,
v_descale,
alibi_slopes=None,
output_scale=None,
qq_bias=None,
# Optional tensor for sinks
sinks=None,
......@@ -674,7 +692,8 @@ def unified_attention(
num_queries_per_kv = num_query_heads // num_kv_heads
head_size = q.shape[2]
BLOCK_M = 16
BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(
num_queries_per_kv)
BLOCK_Q = BLOCK_M // num_queries_per_kv
# Ideally we would launch with kernel with:
......@@ -706,6 +725,7 @@ def unified_attention(
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
......@@ -735,6 +755,7 @@ def unified_attention(
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
)
else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default
......@@ -818,6 +839,8 @@ def unified_attention(
seq_lens_ptr=seqused_k,
num_seqs=num_seqs,
num_query_heads=num_query_heads,
out_scale_inv=1 /
output_scale if output_scale is not None else 1.0,
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
block_table_stride=block_table.stride(0),
......@@ -827,4 +850,5 @@ def unified_attention(
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
USE_FP8=output_scale is not None,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment