Unverified Commit 752c6ade authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Deprecate BlockSparse Attention & Phi3-Small (#21217)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 881e3cbe
......@@ -108,7 +108,6 @@ fi
if [[ $commands == *" kernels/attention"* ]]; then
commands="${commands} \
--ignore=kernels/attention/test_attention_selector.py \
--ignore=kernels/attention/test_blocksparse_attention.py \
--ignore=kernels/attention/test_encoder_decoder_attn.py \
--ignore=kernels/attention/test_flash_attn.py \
--ignore=kernels/attention/test_flashinfer.py \
......
......@@ -376,7 +376,6 @@ Specified using `--task generate`.
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Optional
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# MAX_SEQ_LEN = 2771
# There may not be enough gpu memory due to large NUM_BLOCKS.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
DTYPES = [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [3] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40)] # Arbitrary values for testing
HEAD_SIZES = [64, 112]
BLOCK_SIZES = [16]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']
BLOCKSPARSE_LOCAL_BLOCKS = [16]
BLOCKSPARSE_VERT_STRIDES = [8]
BLOCKSPARSE_BLOCK_SIZES = [64]
BLOCKSPARSE_HEADS_SLIDINGS = [2, -1]
BLOCKSPARSE_HOMO_HEADS = [True, False]
def ref_masked_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
if attn_mask is not None:
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
def ref_single_query_cached_kv_attention(
output: torch.Tensor,
query: torch.Tensor,
num_queries_per_kv: int,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
alibi_slopes: Optional[torch.Tensor],
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 1,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
num_query_heads = query.shape[1]
num_kv_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
block_size = value_cache.shape[3]
num_seqs = query.shape[0]
block_tables_lst = block_tables.cpu().tolist()
seq_lens_lst = seq_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0)
block_table = block_tables_lst[i]
seq_len = int(seq_lens_lst[i])
keys_lst: list[torch.Tensor] = []
values_lst: list[torch.Tensor] = []
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_kv_heads, head_size)
keys_lst.append(k)
v = value_cache[block_number, :, :, block_offset]
values_lst.append(v)
keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values_lst, dim=0)
if num_queries_per_kv > 1:
# Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(seq_len).int()
alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
if blocksparse_vert_stride >= 1:
bsize = blocksparse_block_size
hsliding = blocksparse_head_sliding_step
vert = blocksparse_vert_stride
locals = blocksparse_local_blocks
qb = (seq_len - 1) // bsize
attn_mask = q.new_zeros(
(num_query_heads, 1, seq_len)).float() - torch.inf
for h in range(num_query_heads):
if hsliding >= 0: # slide with q heads
bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1
else: # slide with kv heads
bs_offset = (tp_rank * num_kv_heads +
h // num_queries_per_kv) * (-hsliding) + 1
for kb in range(qb + 1):
kj = kb * bsize
if (qb - kb) < locals or \
(kb + bs_offset) % vert == 0:
attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0
if alibi_bias is not None:
attn_mask += alibi_bias
else:
attn_mask = alibi_bias
out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask)
out = out.view(num_query_heads, head_size)
output[i].copy_(out, non_blocking=True)
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_head_sliding_step",
BLOCKSPARSE_HEADS_SLIDINGS)
def test_paged_attention(
kv_cache_factory,
version: str,
num_seqs: int,
num_heads: tuple[int, int],
head_size: int,
use_alibi: bool,
block_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: str,
blocksparse_local_blocks: int,
blocksparse_vert_stride: int,
blocksparse_block_size: int,
blocksparse_head_sliding_step: int,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.rand(num_query_heads, dtype=torch.float)
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
# Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
tp_rank = 0
# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
)
elif version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step,
)
else:
raise AssertionError(f"Unknown version: {version}")
# Run the reference implementation.
if kv_cache_dtype == "fp8":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(dequantized_key_cache, key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(dequantized_value_cache, value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
ref_output,
query,
num_queries_per_kv,
key_cache,
value_cache,
block_tables,
seq_lens,
scale,
alibi_slopes,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
def ref_multi_query_kv_attention(
cu_seq_lens: list[int],
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: float,
dtype: torch.dtype,
) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1
ref_outputs = []
for i in range(num_seqs):
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
seq_len = end_idx - start_idx
# Create attention mask.
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype)
ref_output = ref_masked_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
scale,
attn_mask=attn_mask,
)
ref_outputs.append(ref_output)
ref_output = torch.cat(ref_outputs, dim=0)
return ref_output
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS)
@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES)
@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_varlen_blocksparse_attention_prefill(
num_seqs: int,
num_heads: tuple[int, int],
head_size: int,
blocksparse_local_blocks: int,
blocksparse_vert_stride: int,
blocksparse_block_size: int,
blocksparse_homo_heads: bool,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
max_len = min(MAX_SEQ_LEN, 4096)
seq_lens = random.sample(range(1, max_len), num_seqs)
cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
bs_attn_op = LocalStridedBlockSparseAttn(
num_query_heads,
max_len,
local_blocks=blocksparse_local_blocks,
vert_stride=blocksparse_vert_stride,
block_size=blocksparse_block_size,
device=device,
dtype=dtype,
homo_head=blocksparse_homo_heads)
output = bs_attn_op(query,
key,
value,
cu_seq_lens.to(device),
sm_scale=scale)
if num_queries_per_kv > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
ref_output = ref_multi_query_kv_attention(
cu_seq_lens.tolist(),
query,
key,
value,
scale,
dtype,
)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
......@@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# change the attention backend to triton MLA
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True)
backend = get_attn_backend(576,
torch.bfloat16,
"auto",
16,
False,
use_mla=True)
assert (backend.get_name() == "TRITON_MLA"
or backend.get_name() == "TRITON_MLA_VLLM_V1")
......@@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# If use_mla is true
# The selected backend is triton MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
False, True)
backend = get_attn_backend(576,
torch.bfloat16,
"auto",
16,
False,
use_mla=True)
assert (backend.get_name() == "TRITON_MLA"
or backend.get_name() == "TRITON_MLA_VLLM_V1")
# change the attention backend to AITER MLA
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
backend = get_attn_backend(576,
torch.bfloat16,
"auto",
1,
False,
use_mla=True)
assert (backend.get_name() == "ROCM_AITER_MLA"
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
......@@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# The selected backend is ROCM_AITER_MLA
m.setenv(STR_BACKEND_ENV_VAR, None)
m.setenv("VLLM_ROCM_USE_AITER", "1")
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
False, True)
backend = get_attn_backend(576,
torch.bfloat16,
"auto",
1,
False,
use_mla=True)
assert (backend.get_name() == "ROCM_AITER_MLA"
or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
......@@ -247,10 +247,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
# Blocksparse attention not supported in V1 yet
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True,
v0_only=True),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
......
......@@ -269,7 +269,6 @@ class AttentionImpl(ABC, Generic[T]):
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
@dataclass
class BlocksparseParams:
max_seqlen: int
# Num q heads per tensor-parallel rank/partition
num_heads: int # per TP partition
# Num kv heads per tensor-parallel rank/partition
num_kv_heads: int
# block size used for blocksparse attention.
# This is the block_size used in `local_blocks`, `vert_stride`.
block_size: int
# Number of blocks for local attention, i.e., number of
# local attended tokens / `sparse_block_size`
local_blocks: int
# Attend to one block per every `vert_stride` blocks.
# Controlling the sparsity
vert_stride: int
"""
If to use the same vertical stride offset for all heads,
i.e., attend to the same block of tokens on all heads.
By default, it is False, i.e., attention on the non-local
blocks depends on the `head_idx`, that is on
blocks satisfying
`(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
`block_idx = position_id // sparse_block_size`.
See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
for more detail.
"""
homo_head: bool = False
# If within a group, the kv offsets that each q attends is the same or no.
homo_head_group: bool = False
# Decided by homo_head and homo_head group
head_sliding_step: int = field(init=False)
# range of q heads to for a TP rank
active_head_range: Tuple = field(init=False)
def __post_init__(self):
assert self.block_size > 0
assert self.local_blocks >= 0
assert self.vert_stride >= 1
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
total_heads = tp_size * self.num_heads
total_kv_heads = tp_size * self.num_kv_heads
if self.homo_head:
self.head_sliding_step = 0
elif self.homo_head_group:
head_sliding_step = get_head_sliding_step(total_kv_heads,
self.vert_stride)
# negative indicates sliding along kv heads, i.e., homo q group
self.head_sliding_step = -head_sliding_step
else:
self.head_sliding_step = get_head_sliding_step(
total_heads, self.vert_stride)
self.active_head_range = (
tp_rank * self.num_heads,
(tp_rank + 1) * self.num_heads,
)
class BlocksparseFlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "BLOCK_SPARSE_FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
return BlocksparseFlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
"""A copy of Metadata for FlashAttentionBackend,
to avoid having to install flash_attn.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Max number of query tokens for among request in the batch.
max_decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
@property
def prefill_metadata(
self) -> Optional["BlocksparseFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None
self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class BlocksparseFlashAttentionMetadataBuilder(
CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
_metadata_cls = BlocksparseFlashAttentionMetadata
class BlocksparseFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0 "
"BLOCK_SPARSE_FLASH_ATTN Backend.")
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.")
assert sliding_window is None, ValueError(
"sliding_window is invalid for blocksparse attention.")
assert logits_soft_cap is None, ValueError(
"logits_soft_cap is invalid for blocksparse attention.")
if "num_heads" not in blocksparse_params:
blocksparse_params["num_heads"] = num_heads
if "num_kv_heads" not in blocksparse_params:
blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
self.blocksparse_params = BlocksparseParams(**blocksparse_params)
self.kv_cache_dtype = kv_cache_dtype
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.alibi_slopes = alibi_slopes
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.local_blocks = self.blocksparse_params.local_blocks
self.vert_stride = self.blocksparse_params.vert_stride
self.sparse_block_size = self.blocksparse_params.block_size
self.head_sliding_step = self.blocksparse_params.head_sliding_step
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
total_num_heads = num_heads * self.tp_size
self.bs_attn = LocalStridedBlockSparseAttn(
total_num_heads,
self.blocksparse_params.max_seqlen,
self.blocksparse_params.local_blocks,
self.blocksparse_params.vert_stride,
self.blocksparse_params.block_size,
homo_head=self.blocksparse_params.homo_head,
active_head_range=self.blocksparse_params.active_head_range,
)
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for BlocksparseFlashAttentionImpl")
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
assert kv_cache.numel() == 0 \
or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0, \
"Does not support prefix-enabled attention."
output = self.bs_attn(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
sm_scale=self.scale,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output = PagedAttention.forward_decode(
query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
self.blocksparse_params.max_seqlen,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
tp_rank=self.tp_rank,
blocksparse_local_blocks=self.local_blocks,
blocksparse_vert_stride=self.vert_stride,
blocksparse_block_size=self.sparse_block_size,
blocksparse_head_sliding_step=self.head_sliding_step,
)
assert output is not None
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
......@@ -667,7 +667,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
......@@ -680,9 +679,6 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
differential_flash_attention_config
self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
if use_irope:
logger.warning(
"Using irope in V0 is not supported yet, it will fall back "
......
......@@ -287,7 +287,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
......
......@@ -4,7 +4,7 @@
from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
import torch
......@@ -615,7 +615,6 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
......@@ -624,9 +623,6 @@ class FlashAttentionImpl(AttentionImpl):
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0 "
"FLASH_ATTN backend.")
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
if use_irope:
logger.warning(
"Using irope in V0 is not supported yet, it will fall back "
......
......@@ -999,7 +999,6 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
......
......@@ -3,7 +3,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import torch
......@@ -181,7 +181,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str] = None,
......@@ -189,20 +188,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
assert is_flashmla_supported(), \
"FlashMLA is not supported on this device"
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
"alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
......
......@@ -997,7 +997,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
......
......@@ -3,7 +3,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Type, Union
from typing import TYPE_CHECKING, Optional, Type, Union
import torch
......@@ -367,7 +367,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
......@@ -375,17 +374,14 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
"alibi_slopes, sliding_window, logits_soft_cap")
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
......
......@@ -4,7 +4,7 @@
import itertools
from dataclasses import dataclass
from functools import cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import torch
......@@ -494,7 +494,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
......@@ -507,9 +506,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
logger.warning_once(
"Using irope in ROCm Flash Attention is not supported yet, it "
"will fail back to global attention for long context.")
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
if use_irope:
logger.warning(
"Using irope in V0 is not supported yet, it will fall back "
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Dict, List, Optional, Type
from typing import List, Optional, Type
import torch
......@@ -35,7 +35,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
......@@ -43,17 +42,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
"alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type
import torch
from xformers import ops as xops
......@@ -387,7 +387,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
......@@ -396,9 +395,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0 "
"XFORMERS backend.")
if blocksparse_params is not None:
raise ValueError(
"XFormers does not support block-sparse attention.")
if logits_soft_cap is not None:
logger.warning_once("XFormers does not support logits soft cap. "
"Outputs may be slightly off.")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import Any, Dict, List, Optional
from typing import List, Optional
import torch
import torch.nn as nn
......@@ -74,7 +74,6 @@ class Attention(nn.Module):
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
......@@ -163,12 +162,11 @@ class Attention(nn.Module):
kv_cache_dtype,
block_size,
is_attention_free,
blocksparse_params is not None,
use_mla=use_mla)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
def blocksparse_flash_attn_varlen_fwd(
q,
k,
v, # (#tokens, n_heads, head_size)
cu_seqlens_k,
cu_seqlens_q,
sm_scale,
sparse_layout,
*,
block_size=64,
q_block_size=None,
max_seqlen=None):
# split q to blocks
assert isinstance(sparse_layout, (list, tuple))
_, n_heads, head_size = q.shape
batch_size = cu_seqlens_k.size(0) - 1
q_block_size = q_block_size or block_size
assert q.dim() == k.dim() == v.dim() == 3
assert q.size(1) % k.size(1) == 0
assert q.size(2) == k.size(2)
# TODO(linxihui): allow k, v to have different head_size
assert k.shape == v.shape
assert cu_seqlens_k.dim() == 1
q_k_ratio = q.size(1) // k.size(1)
if cu_seqlens_q is None:
if q.size(0) == batch_size: # decoding only
cu_seqlens_q = torch.arange(
0,
batch_size + 1,
dtype=cu_seqlens_k.dtype,
device=cu_seqlens_k.device,
)
elif q.size(0) == k.size(0):
cu_seqlens_q = cu_seqlens_k
else:
raise ValueError("cu_seqlens_q must be specified\
if it mix of prefilling and decoding.")
else:
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
# switch to use cpu to avoid too many kernel launches when iterated over
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
"length of q should either be 1 (decoding) or same as k (prefilling).")
if max_seqlen:
assert k_lens.max() <= max_seqlen
n_blocks = (q_lens + q_block_size - 1) // q_block_size
q_batch_ids = torch.tensor(
[i for i, n in enumerate(n_blocks) for _ in range(n)],
dtype=cu_seqlens_q.dtype,
device=cu_seqlens_q.device,
)
q_start_sids = torch.tensor(
[i * q_block_size for n in n_blocks for i in range(n)],
dtype=cu_seqlens_q.dtype,
device=cu_seqlens_q.device,
)
out = q.new_empty(q.shape)
cu_seqlens_q = cu_seqlens_q.contiguous()
cu_seqlens_k = cu_seqlens_k.contiguous()
layout_crow_indices, layout_col_indices = sparse_layout
block_d = triton.next_power_of_2(head_size)
decoding_only = (q_lens == 1).all().item()
grid = (len(q_start_sids), n_heads, 1)
_fwd_kernel_batch_inference[grid](
q,
k,
v,
out,
sm_scale,
cu_seqlens_q[:-1],
cu_seqlens_q[1:],
cu_seqlens_k[:-1],
cu_seqlens_k[1:],
q_batch_ids,
q_start_sids,
0,
*q.stride(),
0,
*k.stride(),
0,
*v.stride(),
0,
*out.stride(),
layout_crow_indices,
layout_col_indices,
*layout_crow_indices.stride(),
*layout_col_indices.stride(),
q_k_ratio,
HAS_BATCH_DIM=False,
D_HEAD=head_size,
BLOCK_M=q_block_size,
BLOCK_N=block_size,
BLOCK_D=block_d,
BLOCK_M_LOADING=(16 if decoding_only else
q_block_size), # smaller for decoding
EVEN_D=block_d == head_size,
num_warps=1 if decoding_only else 4,
num_stages=3)
return out
@triton.jit
def _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
Q,
k_block_col_idx,
layout_col_ptr,
layout_col_stride_h,
layout_col_stride_m,
k_ptrs,
v_ptrs,
off_h,
offs_m,
offs_n,
offs_d,
stride_kt,
stride_vt,
sm_scale,
k_seqlen,
past_len,
LAST_K_BLOCK: tl.constexpr,
BLOCK_M_LOADING: tl.constexpr,
BLOCK_N: tl.constexpr,
D_HEAD: tl.constexpr,
EVEN_D: tl.constexpr,
M_LT_N: tl.constexpr,
):
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
k_block_col_idx * layout_col_stride_m).to(tl.int32)
start_n = k_block_id * BLOCK_N
if LAST_K_BLOCK:
if EVEN_D:
k = tl.load(
k_ptrs + start_n * stride_kt,
mask=offs_n[None, :] + start_n < k_seqlen,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kt,
mask=(offs_n[None, :] + start_n < k_seqlen) &
(offs_d[:, None] < D_HEAD),
other=0.0,
)
else:
if EVEN_D:
k = tl.load(k_ptrs + start_n * stride_kt)
else:
k = tl.load(k_ptrs + start_n * stride_kt,
mask=offs_d[:, None] < D_HEAD,
other=0.0)
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
# the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
if LAST_K_BLOCK | M_LT_N:
qk += tl.where(
offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
0,
float("-inf"),
)
# flash-attn2
m_ij = tl.maximum(m_i, tl.max(qk, 1))
p = tl.math.exp2(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
# update m_i
m_i = m_ij
l_i = l_i * alpha + l_ij
p = p.to(Q.dtype.element_ty)
# update acc
if LAST_K_BLOCK:
if EVEN_D:
v = tl.load(
v_ptrs + start_n * stride_vt,
mask=offs_n[:, None] + start_n < k_seqlen,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vt,
mask=(offs_n[:, None] + start_n < k_seqlen) &
(offs_d[None, :] < D_HEAD),
other=0.0,
)
else:
if EVEN_D:
v = tl.load(v_ptrs + start_n * stride_vt)
else:
v = tl.load(v_ptrs + start_n * stride_vt,
mask=offs_d[None, :] < D_HEAD,
other=0.0)
acc += tl.dot(p, v)
return acc, l_i, m_i
@triton.heuristics({
"M_LT_N":
lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
})
@triton.jit
def _fwd_kernel_batch_inference(
Q,
K,
V,
Out,
sm_scale,
q_batch_starts,
q_batch_ends,
k_batch_starts,
k_batch_ends,
q_batch_ids,
q_start_sids,
stride_qb,
stride_qt,
stride_qh,
stride_qd,
stride_kb,
stride_kt,
stride_kh,
stride_kd,
stride_vb,
stride_vt,
stride_vh,
stride_vd,
stride_ob,
stride_ot,
stride_oh,
stride_od,
layout_crow_ptr,
layout_col_ptr,
layout_crow_stride_h,
layout_crow_stride_m,
layout_col_stride_h,
layout_col_stride_m,
q_k_ratio,
HAS_BATCH_DIM: tl.constexpr,
D_HEAD: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
BLOCK_M_LOADING: tl.constexpr,
EVEN_D: tl.constexpr,
M_LT_N: tl.constexpr,
):
"""
NOTATION:
pid: position id
sid: storage id
sbid: storage block id
pbid: position block id
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
TODO(linxihui):
Optimize grouped-attn
"""
off_zm = tl.program_id(0)
off_h = tl.program_id(1)
off_h_for_kv = off_h // q_k_ratio
if HAS_BATCH_DIM:
off_z = tl.program_id(2)
Q += off_z * stride_qb
K += off_z * stride_kb
V += off_z * stride_vb
Out += off_z * stride_ob
start_m = off_zm
q_start_sid = start_m * BLOCK_M # always 0 for decoding
else:
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
q_start_sid = tl.load(q_start_sids + off_zm)
start_m = q_start_sid // BLOCK_M # q_sbid
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
past_len = k_seqlen - q_seqlen
Q += q_cu_start * stride_qt + off_h * stride_qh
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
Out += q_cu_start * stride_ot + off_h * stride_oh
q_pbid = (past_len + q_start_sid) // BLOCK_M
if EVEN_D:
q = tl.load(
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
mask=offs_m[:, None] < q_seqlen,
other=0.0,
)
else:
q = tl.load(
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
other=0.0,
)
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
q_pbid * layout_crow_stride_m)
# TODO(linxihui): load at once, with any Triton version
# that supports `tl.split`, e.g., Triton 3.0
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
sm_scale *= (
1.44269504 # 1/log2 as we use base2 for exponential and logarithm
)
for k_block_col_idx in range(k_block_start, k_block_end - 1):
acc, l_i, m_i = _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
Q,
k_block_col_idx,
layout_col_ptr,
layout_col_stride_h,
layout_col_stride_m,
k_ptrs,
v_ptrs,
off_h,
offs_m,
offs_n,
offs_d,
stride_kt,
stride_vt,
sm_scale,
k_seqlen,
past_len,
False,
BLOCK_M_LOADING,
BLOCK_N,
D_HEAD,
EVEN_D,
M_LT_N,
)
acc, l_i, m_i = _fwd_kernel_inner(
acc,
l_i,
m_i,
q,
Q,
k_block_end - 1,
layout_col_ptr,
layout_col_stride_h,
layout_col_stride_m,
k_ptrs,
v_ptrs,
off_h,
offs_m,
offs_n,
offs_d,
stride_kt,
stride_vt,
sm_scale,
k_seqlen,
past_len,
True,
BLOCK_M_LOADING,
BLOCK_N,
D_HEAD,
EVEN_D,
M_LT_N,
)
# flash-attn 2
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
# write output
if EVEN_D:
tl.store(
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
acc,
mask=offs_m[:, None] < q_seqlen,
)
else:
tl.store(
Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
acc,
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment