Commit 711aa9d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import itertools import itertools
from dataclasses import dataclass from dataclasses import dataclass
from functools import cache from functools import cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import torch import torch
import triton import triton
...@@ -21,7 +21,9 @@ from vllm.attention.ops.paged_attn import (PagedAttention, ...@@ -21,7 +21,9 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.utils import SUPPORT_TC, gpuname from vllm.utils import SUPPORT_TC, gpuname
...@@ -502,21 +504,18 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -502,21 +504,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
alibi_slopes: Optional[List[float]], alibi_slopes: Optional[List[float]],
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None, kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
"ROCM_FLASH backend.")
if use_irope: if use_irope:
logger.warning_once( logger.warning_once(
"Using irope in ROCm Flash Attention is not supported yet, it " "Using irope in ROCm Flash Attention is not supported yet, it "
"will fail back to global attention for long context.") "will fail back to global attention for long context.")
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
if use_irope: if use_irope:
logger.warning( logger.warning(
"Using irope in V0 is not supported yet, it will fall back " "Using irope in V0 is not supported yet, it will fall back "
...@@ -616,10 +615,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -616,10 +615,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim)) head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]): group_shape: GroupShape):
if self.use_triton_flash_attn: if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype( return dtype == current_platform.fp8_dtype(
) and static and group_shape == (-1, -1) # per-tensor ) and static and group_shape == GroupShape.PER_TENSOR
# Only supported in the Triton backend # Only supported in the Triton backend
return False return False
......
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from typing import List, Optional, Type, Any, Dict
from typing import Any, Dict, List, Optional, Type
from .triton_config import get_nearest_config, get_attention_mla_configs, get_config, get_attention_mla_configs_json from .triton_config import get_nearest_config, get_attention_mla_configs, get_config, get_attention_mla_configs_json
import torch import torch
...@@ -42,7 +41,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -42,7 +41,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
alibi_slopes: Optional[List[float]], alibi_slopes: Optional[List[float]],
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float], logits_soft_cap: Optional[float],
attn_type: str, attn_type: str,
kv_sharing_target_layer_name: Optional[str], kv_sharing_target_layer_name: Optional[str],
...@@ -50,17 +48,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -50,17 +48,14 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
**mla_args) -> None: **mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads, super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args) kv_sharing_target_layer_name, **mla_args)
unsupported_features = [ unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features): if any(unsupported_features):
raise NotImplementedError( raise NotImplementedError(
"TritonMLAImpl does not support one of the following: " "TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, " "alibi_slopes, sliding_window, logits_soft_cap")
"logits_soft_cap")
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with xFormers and PagedAttention.""" """Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
import torch import torch
from xformers import ops as xops from xformers import ops as xops
...@@ -394,17 +394,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -394,17 +394,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
alibi_slopes: Optional[List[float]], alibi_slopes: Optional[List[float]],
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None, kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.") raise NotImplementedError("KV sharing is not supported in V0 "
if blocksparse_params is not None: "XFORMERS backend.")
raise ValueError(
"XFormers does not support block-sparse attention.")
if logits_soft_cap is not None: if logits_soft_cap is not None:
logger.warning_once("XFormers does not support logits soft cap. " logger.warning_once("XFormers does not support logits soft cap. "
"Outputs may be slightly off.") "Outputs may be slightly off.")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer.""" """Attention layer."""
from typing import Any, Dict, List, Optional from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,18 +10,47 @@ import torch.nn.functional as F ...@@ -10,18 +10,47 @@ import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group, has_kv_transfer_group,
is_v1_kv_transfer_group) is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
def check_xformers_availability():
    """Return whether the xformers attention ops can be used.

    The answer is computed once and cached in the module-level
    ``USE_XFORMERS_OPS`` flag; later calls return the cached value without
    probing (or warning) again.

    Returns:
        bool: ``False`` when the current platform is CUDA with device
        capability 100 or above, or when ``xformers.ops`` cannot be located;
        ``True`` otherwise.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec
        try:
            # find_spec returns None (rather than raising) when the parent
            # package exists but the submodule does not, so the result has
            # to be checked explicitly.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except (ImportError, ValueError):
            # ImportError: the parent package ``xformers`` is missing.
            # ValueError: a module already in sys.modules has no __spec__.
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS
class Attention(nn.Module): class Attention(nn.Module):
...@@ -45,7 +74,6 @@ class Attention(nn.Module): ...@@ -45,7 +74,6 @@ class Attention(nn.Module):
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None, per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False, use_mla: bool = False,
...@@ -109,6 +137,15 @@ class Attention(nn.Module): ...@@ -109,6 +137,15 @@ class Attention(nn.Module):
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window self.sliding_window = sliding_window
# For v1 we have backend agnostic iRoPE (local chunked attention)
# we have to store the flag on the layer so gpu model runner can
# set KVSpec appropriately (and pop it so it doesnt get passed to
# the backends)
if envs.VLLM_USE_V1:
self.use_irope = extra_impl_args.pop("use_irope", False)
else:
self.use_irope = extra_impl_args.get("use_irope", False)
quant_method = quant_config.get_quant_method( quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None self, prefix=prefix) if quant_config else None
if quant_method is not None and not isinstance( if quant_method is not None and not isinstance(
...@@ -134,12 +171,11 @@ class Attention(nn.Module): ...@@ -134,12 +171,11 @@ class Attention(nn.Module):
kv_cache_dtype, kv_cache_dtype,
block_size, block_size,
is_attention_free, is_attention_free,
blocksparse_params is not None,
use_mla=use_mla) use_mla=use_mla)
impl_cls = attn_backend.get_impl_cls() impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype, alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args) kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(attn_backend.get_name()) self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype self.dtype = dtype
...@@ -160,10 +196,6 @@ class Attention(nn.Module): ...@@ -160,10 +196,6 @@ class Attention(nn.Module):
self.attn_type = attn_type self.attn_type = attn_type
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
if not envs.VLLM_USE_V1:
raise NotImplementedError(
"Cross-layer KV sharing is not supported in V0.")
validate_kv_sharing_target( validate_kv_sharing_target(
prefix, prefix,
kv_sharing_target_layer_name, kv_sharing_target_layer_name,
...@@ -318,6 +350,10 @@ class MultiHeadAttention(nn.Module): ...@@ -318,6 +350,10 @@ class MultiHeadAttention(nn.Module):
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
} else _Backend.TORCH_SDPA } else _Backend.TORCH_SDPA
if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
self.attn_backend = _Backend.TORCH_SDPA
def forward( def forward(
self, self,
query: torch.Tensor, query: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
def blocksparse_flash_attn_varlen_fwd(
        q,
        k,
        v,  # (#tokens, n_heads, head_size)
        cu_seqlens_k,
        cu_seqlens_q,
        sm_scale,
        sparse_layout,
        *,
        block_size=64,
        q_block_size=None,
        max_seqlen=None):
    """Run block-sparse attention over a packed (variable-length) batch.

    Splits every query sequence into blocks of ``q_block_size`` rows and
    launches ``_fwd_kernel_batch_inference`` with one program per
    (query block, head).

    Args:
        q, k, v: 3-D tensors laid out as (#tokens, n_heads, head_size).
            ``q`` may have a multiple of ``k``'s head count (grouped
            attention); ``k`` and ``v`` must have identical shapes.
        cu_seqlens_k: 1-D tensor of ``batch_size + 1`` cumulative key
            offsets.
        cu_seqlens_q: cumulative query offsets, or ``None``. When ``None``
            it is inferred: ``[0, 1, ..., batch_size]`` if ``q`` has exactly
            one token per sequence, or ``cu_seqlens_k`` if ``q`` and ``k``
            have the same number of tokens.
        sm_scale: softmax scale passed to the kernel.
        sparse_layout: list/tuple unpacked as
            ``(crow_indices, col_indices)`` of the block layout.
        block_size: key block size (kernel ``BLOCK_N``).
        q_block_size: query block size (kernel ``BLOCK_M``); defaults to
            ``block_size``.
        max_seqlen: optional upper bound asserted on the key lengths.

    Returns:
        A new tensor with the same shape as ``q``.

    Raises:
        ValueError: if ``cu_seqlens_q`` is ``None`` and cannot be inferred.
        AssertionError: if a shape/length precondition is violated.
    """
    # split q to blocks

    assert isinstance(sparse_layout, (list, tuple))

    _, n_heads, head_size = q.shape
    batch_size = cu_seqlens_k.size(0) - 1
    q_block_size = q_block_size or block_size

    assert q.dim() == k.dim() == v.dim() == 3
    assert q.size(1) % k.size(1) == 0
    assert q.size(2) == k.size(2)
    # TODO(linxihui): allow k, v to have different head_size
    assert k.shape == v.shape
    assert cu_seqlens_k.dim() == 1

    q_k_ratio = q.size(1) // k.size(1)

    if cu_seqlens_q is None:
        if q.size(0) == batch_size:  # decoding only
            cu_seqlens_q = torch.arange(
                0,
                batch_size + 1,
                dtype=cu_seqlens_k.dtype,
                device=cu_seqlens_k.device,
            )
        elif q.size(0) == k.size(0):
            cu_seqlens_q = cu_seqlens_k
        else:
            # Implicit string concatenation keeps source indentation out of
            # the message (a backslash continuation inside the literal
            # would embed it).
            raise ValueError("cu_seqlens_q must be specified "
                             "if it is a mix of prefilling and decoding.")
    else:
        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)

    # switch to use cpu to avoid too many kernel launches when iterated over
    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()

    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
        "length of q should either be 1 (decoding) or same as k (prefilling).")

    if max_seqlen:
        assert k_lens.max() <= max_seqlen

    n_blocks = (q_lens + q_block_size - 1) // q_block_size

    # Plain Python ints: avoids creating a 0-dim tensor per element while
    # building the per-block index lists below.
    blocks_per_seq = n_blocks.tolist()

    # For every query block: the sequence it belongs to ...
    q_batch_ids = torch.tensor(
        [i for i, n in enumerate(blocks_per_seq) for _ in range(n)],
        dtype=cu_seqlens_q.dtype,
        device=cu_seqlens_q.device,
    )
    # ... and its starting row inside that sequence.
    q_start_sids = torch.tensor(
        [i * q_block_size for n in blocks_per_seq for i in range(n)],
        dtype=cu_seqlens_q.dtype,
        device=cu_seqlens_q.device,
    )

    out = q.new_empty(q.shape)
    cu_seqlens_q = cu_seqlens_q.contiguous()
    cu_seqlens_k = cu_seqlens_k.contiguous()

    layout_crow_indices, layout_col_indices = sparse_layout
    block_d = triton.next_power_of_2(head_size)

    decoding_only = (q_lens == 1).all().item()
    grid = (len(q_start_sids), n_heads, 1)

    # The literal 0 before each stride tuple fills the kernel's batch-stride
    # slot, which is unused because HAS_BATCH_DIM=False.
    _fwd_kernel_batch_inference[grid](
        q,
        k,
        v,
        out,
        sm_scale,
        cu_seqlens_q[:-1],
        cu_seqlens_q[1:],
        cu_seqlens_k[:-1],
        cu_seqlens_k[1:],
        q_batch_ids,
        q_start_sids,
        0,
        *q.stride(),
        0,
        *k.stride(),
        0,
        *v.stride(),
        0,
        *out.stride(),
        layout_crow_indices,
        layout_col_indices,
        *layout_crow_indices.stride(),
        *layout_col_indices.stride(),
        q_k_ratio,
        HAS_BATCH_DIM=False,
        D_HEAD=head_size,
        BLOCK_M=q_block_size,
        BLOCK_N=block_size,
        BLOCK_D=block_d,
        BLOCK_M_LOADING=(16 if decoding_only else
                         q_block_size),  # smaller for decoding
        EVEN_D=block_d == head_size,
        num_warps=1 if decoding_only else 4,
        num_stages=3)

    return out
@triton.jit
def _fwd_kernel_inner(
    acc,
    l_i,
    m_i,
    q,
    Q,
    k_block_col_idx,
    layout_col_ptr,
    layout_col_stride_h,
    layout_col_stride_m,
    k_ptrs,
    v_ptrs,
    off_h,
    offs_m,
    offs_n,
    offs_d,
    stride_kt,
    stride_vt,
    sm_scale,
    k_seqlen,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    BLOCK_N: tl.constexpr,
    D_HEAD: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    """Fold one key/value block into the running softmax state.

    Looks up the key block id stored at ``k_block_col_idx`` in the layout's
    column-index array for head ``off_h``, loads that K tile and V tile,
    and updates the accumulators.

    acc, l_i, m_i: running output accumulator, running denominator and
        running row maximum; the updated triple is returned.
    q: query tile already loaded by the caller.
    Q: query pointer; only its element type is used here (cast of ``p``).
    k_ptrs, v_ptrs: tile pointer grids; the block's row offset
        (``start_n * stride``) is added here before loading.
    sm_scale: scale applied to ``q @ k``; scores are exponentiated with
        ``exp2``, so the caller is expected to have folded in log2(e).
    LAST_K_BLOCK: when true, loads are masked against ``k_seqlen``.
    EVEN_D: when false, loads are additionally masked to ``offs_d < D_HEAD``.
    """
    # Translate the position in the column-index list into a key block id,
    # then into the first key row of that block.
    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
                         k_block_col_idx * layout_col_stride_m).to(tl.int32)
    start_n = k_block_id * BLOCK_N
    if LAST_K_BLOCK:
        # Last block may run past the end of the sequence: mask on k_seqlen.
        if EVEN_D:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=offs_n[None, :] + start_n < k_seqlen,
                other=0.0,
            )
        else:
            k = tl.load(
                k_ptrs + start_n * stride_kt,
                mask=(offs_n[None, :] + start_n < k_seqlen) &
                (offs_d[:, None] < D_HEAD),
                other=0.0,
            )
    else:
        if EVEN_D:
            k = tl.load(k_ptrs + start_n * stride_kt)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_d[:, None] < D_HEAD,
                        other=0.0)

    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
    qk += tl.dot(q, k)
    qk *= sm_scale

    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
    # Causal mask: a query at absolute position offs_m + past_len may only
    # see keys at positions <= its own; others get -inf.
    if LAST_K_BLOCK | M_LT_N:
        qk += tl.where(
            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
            0,
            float("-inf"),
        )

    # flash-attn2
    # Online softmax: new row max, probabilities relative to it, and the
    # factor (alpha) that rescales previously accumulated values.
    m_ij = tl.maximum(m_i, tl.max(qk, 1))
    p = tl.math.exp2(qk - m_ij[:, None])
    l_ij = tl.sum(p, 1)
    alpha = tl.math.exp2(m_i - m_ij)
    acc = acc * alpha[:, None]
    # update m_i
    m_i = m_ij
    l_i = l_i * alpha + l_ij

    # Cast probabilities to the query element type before the P @ V dot.
    p = p.to(Q.dtype.element_ty)
    # update acc
    if LAST_K_BLOCK:
        if EVEN_D:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=offs_n[:, None] + start_n < k_seqlen,
                other=0.0,
            )
        else:
            v = tl.load(
                v_ptrs + start_n * stride_vt,
                mask=(offs_n[:, None] + start_n < k_seqlen) &
                (offs_d[None, :] < D_HEAD),
                other=0.0,
            )
    else:
        if EVEN_D:
            v = tl.load(v_ptrs + start_n * stride_vt)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_d[None, :] < D_HEAD,
                        other=0.0)

    acc += tl.dot(p, v)

    return acc, l_i, m_i
# M_LT_N is derived at launch time from the block sizes; it tells the inner
# step whether the causal mask must be applied on every key block.
@triton.heuristics({
    "M_LT_N":
    lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
})
@triton.jit
def _fwd_kernel_batch_inference(
    Q,
    K,
    V,
    Out,
    sm_scale,
    q_batch_starts,
    q_batch_ends,
    k_batch_starts,
    k_batch_ends,
    q_batch_ids,
    q_start_sids,
    stride_qb,
    stride_qt,
    stride_qh,
    stride_qd,
    stride_kb,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_vb,
    stride_vt,
    stride_vh,
    stride_vd,
    stride_ob,
    stride_ot,
    stride_oh,
    stride_od,
    layout_crow_ptr,
    layout_col_ptr,
    layout_crow_stride_h,
    layout_crow_stride_m,
    layout_col_stride_h,
    layout_col_stride_m,
    q_k_ratio,
    HAS_BATCH_DIM: tl.constexpr,
    D_HEAD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_M_LOADING: tl.constexpr,
    EVEN_D: tl.constexpr,
    M_LT_N: tl.constexpr,
):
    """
    Block-sparse attention forward pass for one (query block, head) program.

    Program axis 0 selects the query block, axis 1 the query head (axis 2
    is read as the batch index only when HAS_BATCH_DIM is set). The program
    loads its query tile, walks the key blocks listed for its row in the
    CSR-style layout (crow/col arrays), and stores the normalised result
    into Out.

    NOTATION:
    pid: position id
    sid: storage id
    sbid: storage block id
    pbid: position block id
    offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)

    TODO(linxihui):
    Optimize grouped-attn
    """
    off_zm = tl.program_id(0)
    off_h = tl.program_id(1)

    # Grouped attention: q_k_ratio consecutive query heads share a KV head.
    off_h_for_kv = off_h // q_k_ratio

    if HAS_BATCH_DIM:
        off_z = tl.program_id(2)
        Q += off_z * stride_qb
        K += off_z * stride_kb
        V += off_z * stride_vb
        Out += off_z * stride_ob
        start_m = off_zm
        q_start_sid = start_m * BLOCK_M  # always 0 for decoding
    else:
        # Packed layout: map the flat query-block index to its sequence and
        # to the block's first row within that sequence.
        off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)  # [0, 0, 0, 1]
        q_start_sid = tl.load(q_start_sids + off_zm)
        start_m = q_start_sid // BLOCK_M  # q_sbid

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    # Sequence bounds of this batch element inside the packed q / k tensors.
    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
    # Number of keys that precede the first query token.
    past_len = k_seqlen - q_seqlen

    Q += q_cu_start * stride_qt + off_h * stride_qh
    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
    Out += q_cu_start * stride_ot + off_h * stride_oh

    # Block row of the layout, computed from absolute (position) indices.
    q_pbid = (past_len + q_start_sid) // BLOCK_M

    if EVEN_D:
        q = tl.load(
            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
            mask=offs_m[:, None] < q_seqlen,
            other=0.0,
        )
    else:
        q = tl.load(
            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
            other=0.0,
        )

    sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
                       q_pbid * layout_crow_stride_m)

    # TODO(linxihui): load at once, with any Triton version
    # that supports `tl.split`, e.g., Triton 3.0
    # [k_block_start, k_block_end) indexes this row's entries in the
    # column-index array.
    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)

    # Running softmax state: row max, denominator and output accumulator.
    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)

    # K is addressed as a (BLOCK_D, BLOCK_N) tile, V as (BLOCK_N, BLOCK_D).
    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd

    sm_scale *= (
        1.44269504  # 1/log2 as we use base2 for exponential and logarithm
    )

    # All blocks except the last are processed without sequence-length
    # masking (LAST_K_BLOCK=False).
    for k_block_col_idx in range(k_block_start, k_block_end - 1):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc,
            l_i,
            m_i,
            q,
            Q,
            k_block_col_idx,
            layout_col_ptr,
            layout_col_stride_h,
            layout_col_stride_m,
            k_ptrs,
            v_ptrs,
            off_h,
            offs_m,
            offs_n,
            offs_d,
            stride_kt,
            stride_vt,
            sm_scale,
            k_seqlen,
            past_len,
            False,
            BLOCK_M_LOADING,
            BLOCK_N,
            D_HEAD,
            EVEN_D,
            M_LT_N,
        )

    # The final block is handled separately with LAST_K_BLOCK=True.
    # NOTE(review): this call is unconditional, so it presumably relies on
    # every layout row listing at least one key block — confirm.
    acc, l_i, m_i = _fwd_kernel_inner(
        acc,
        l_i,
        m_i,
        q,
        Q,
        k_block_end - 1,
        layout_col_ptr,
        layout_col_stride_h,
        layout_col_stride_m,
        k_ptrs,
        v_ptrs,
        off_h,
        offs_m,
        offs_n,
        offs_d,
        stride_kt,
        stride_vt,
        sm_scale,
        k_seqlen,
        past_len,
        True,
        BLOCK_M_LOADING,
        BLOCK_N,
        D_HEAD,
        EVEN_D,
        M_LT_N,
    )

    # flash-attn 2
    # m_i is updated here but is not stored anywhere in this kernel.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]

    # write output
    if EVEN_D:
        tl.store(
            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
            acc,
            mask=offs_m[:, None] < q_seqlen,
        )
    else:
        tl.store(
            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
            acc,
            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch
from vllm.platforms import current_platform
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)
IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
class LocalStridedBlockSparseAttn(torch.nn.Module):
    """Block-sparse attention with a local + vertically strided pattern.

    The sparse layout is built once at construction time. ``forward``
    dispatches either to the Triton kernel (``varlen_attn``) or to a padded
    ``torch`` scaled-dot-product fallback (``spda``), depending on
    ``use_spda``.
    """

    def __init__(
        self,
        n_heads,
        max_seqlen,
        local_blocks,
        vert_stride,
        block_size,
        device=None,
        dtype=None,
        homo_head=False,
        active_head_range=None,
        q_block_size=None,
        use_spda=None,
    ):
        """Build the sparse layout (and, for the fallback, the dense mask).

        Args:
            n_heads: number of attention heads in the pattern.
            max_seqlen: maximum sequence length the pattern covers.
            local_blocks: number of local blocks each query block attends to.
            vert_stride: stride of the vertical (strided) blocks.
            block_size: key block size.
            device: target device; defaults to the current CUDA-like device
                or ``"cpu"``.
            dtype: defaults to bfloat16 on compute capability >= 8 or on
                CPU, otherwise float16.
            homo_head: whether all heads share one pattern.
            active_head_range: optional ``(start, end)`` tuple selecting a
                slice of heads (only applied when ``homo_head`` is false).
            q_block_size: query block size; must be a multiple of
                ``block_size`` when different from it.
            use_spda: force the dense fallback; defaults to true on ROCm,
                CPU, or compute capability below 8.

        Raises:
            ValueError: if ``q_block_size`` is smaller than ``block_size``.
        """
        super().__init__()
        if use_spda is None:
            use_spda = current_platform.is_rocm() or \
                       current_platform.is_cpu() or not \
                       IS_COMPUTE_8_OR_ABOVE
        # Compare against None explicitly: an integer device index of 0 is
        # falsy and must not be replaced by the default.
        if device is None:
            device = (torch.cuda.current_device()
                      if current_platform.is_cuda_alike() else "cpu")
        device = torch.device(device)
        # NOTE: vllm CPU backend support BF16 instead of FP16.
        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
                          or device.type == "cpu" else torch.half)

        self.n_heads = n_heads
        self.max_seqlen = max_seqlen
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.use_spda = use_spda
        self.dtype = dtype
        self.device = device
        self.block_size = block_size
        self.q_block_size = q_block_size
        self.homo_head = homo_head
        self.active_head_range = active_head_range
        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
                                                       homo_head)

        sparse_layout, sparse_pattern, self.dense_attn_mask = (
            self.get_attn_pattern(dtype, device))

        if q_block_size is not None and q_block_size != block_size:
            if q_block_size > block_size:
                assert q_block_size % block_size == 0
                # Merge consecutive query block rows so that one layout row
                # corresponds to one (larger) query block.
                blocks_to_merge = q_block_size // block_size
                shape = sparse_pattern.shape
                sparse_pattern = sparse_pattern.view(shape[0], -1,
                                                     blocks_to_merge,
                                                     shape[-1])
                sparse_pattern = sparse_pattern.sum(2)
                sparse_layout = dense_to_crow_col(sparse_pattern)
            else:
                raise ValueError(
                    "Does not support smaller q_block_size. It will be slower."
                )

        self.sparse_layout = sparse_layout

    def get_attn_pattern(self, dtype, device):
        """Return ``(sparse_layout, sparse_pattern, dense_attn_mask)``.

        The dense mask is only requested from ``get_sparse_attn_mask`` when
        ``use_spda`` is set. When heads are not homogeneous and
        ``active_head_range`` is given, the layout (and dense mask) are
        sliced to that head range.
        """
        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
            self.n_heads,
            self.max_seqlen,
            self.max_seqlen,
            dtype,
            device,
            block_size=self.block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=self.use_spda,
            dense_mask_type="bias",
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert isinstance(self.active_head_range, tuple)
            assert (len(self.active_head_range) == 2)
            h_start, h_end = self.active_head_range
            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
            if self.use_spda:
                dense_attn_mask = dense_attn_mask[h_start:h_end]
        return sparse_layout, sparse_pattern, dense_attn_mask

    def varlen_attn(self,
                    q,
                    k,
                    v,
                    cu_seqlens_k,
                    cu_seqlens_q=None,
                    sm_scale=None):
        """Block-sparse attention via the Triton kernel.

        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
        Support grouped attention, with `q[:, i*r:(i*r + r)]`
        is correspondent to `k[:, i]`, where `r` is the q/k ratio.
        cu_seqlens_k: shape=(batch_size + 1,),
        indicating segment of samples,
        e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is k of sample i
        cu_seqlens_q: shape=(batch_size + 1, ).
        Default None: same as cu_seqlens_k for prefilling or
        [0, 1, .., batch_size] for decoding.
        The only case you need to specify is when q is a mix of
        prefilling and decoding.
        sm_scale: softmax scale, default to 1/sqrt(head_size).

        return: tensor of shape as q.
        """
        # Implicit concatenation keeps source indentation out of the message.
        assert IS_COMPUTE_8_OR_ABOVE, (
            "Requires compute capability of 8 or above (Ampere or newer) "
            "to use Triton kernel.")

        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))

        return blocksparse_flash_attn_varlen_fwd(
            q,
            k,
            v,
            cu_seqlens_k,
            cu_seqlens_q,
            sm_scale,
            self.sparse_layout,
            block_size=self.block_size,
            q_block_size=self.q_block_size,
            max_seqlen=self.max_seqlen,
        )

    @staticmethod
    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
        """Unpack a token-packed tensor into a zero-padded batch.

        :param x: (total_tokens, n_heads, head_size)
        :param cu_seqlens: cumulative sequence offsets, length batch + 1
        :param maxlen: padded sequence length
        :param head_repeats: how many times each head is repeated
        :return: (batch, n_heads * head_repeats, maxlen, head_size)
        """
        # Zero-fill the padding: uninitialised memory may hold NaN/Inf,
        # which would propagate through attention even at masked positions
        # (0 * NaN is NaN).
        x_padded = x.new_zeros(
            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
        cu_seqlens = cu_seqlens.cpu()
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            # unsqueeze(1) lets copy_ broadcast over the head_repeats axis.
            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
                                                             1).unsqueeze(1))
        return x_padded.flatten(1, 2)

    @staticmethod
    def transpose_and_unpad(x_padded, cu_seqlens):
        """Pack a padded batch back into a token-packed tensor.

        :param x_padded: (batch, n_heads, length, head_size)
        :param cu_seqlens: cumulative sequence offsets, length batch + 1
        :return: (total_tokens, n_heads, head_size)
        """
        cu_seqlens = cu_seqlens.cpu()
        total_n_tokens = int(cu_seqlens[-1])
        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
                               x_padded.size(3))
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
        return x

    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """For CPU, V100 or other older GPUs.
        NOTE: torch SPDA supports nested tensor,
        but seems extremely slow. Choose to pad instead.

        Only prompt (prefill) batches are supported: q and k must have the
        same number of tokens and, if given, identical cumulative offsets.
        """
        assert (cu_seqlens_q is None or
                (cu_seqlens_q
                 == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
        assert q.size(0) == k.size(0), "can only handle prompt with SPDA."

        assert q.size(1) % k.size(1) == 0
        q_k_ratio = q.size(1) // k.size(1)
        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
        cu_seqlens = cu_seqlens_k.cpu()
        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        # Rebuild the cached dense mask if the inputs use another dtype or
        # device than the one it was created for.
        if (self.dense_attn_mask.dtype != q.dtype
                or self.dense_attn_mask.device != q.device):
            _, _, self.dense_attn_mask = self.get_attn_pattern(
                q.dtype, q.device)
        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]

        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
        # k / v heads are repeated to match the number of query heads.
        k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
                  for x in [k, v])
        spda_output = torch.nn.functional.scaled_dot_product_attention(
            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
        return self.transpose_and_unpad(spda_output, cu_seqlens)

    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """Dispatch to `varlen_attn` (Ampere or newer) or
        `self.spda`(cpu, Volta, Turing or older)based on
        the type of device used and cuda compute capability.

        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
        Support grouped attention, with `q[:, i*r:(i*r + r)]`
        is correspondent to `k[:, i]`, where `r` is the q/k ratio.
        cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
        e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is k of sample i
        cu_seqlens_q: shape=(batch_size + 1, ).
        Default None: same as cu_seqlens_k for prefilling or
        [0, 1, .., batch_size] for decoding.
        The only case you need to specify
        is when q is a mix of prefilling
        and decoding.
        sm_scale: softmax scale, default to 1/sqrt(head_size).

        return: tensor of shape as q.
        """
        assert k.dim() == 3
        if self.use_spda:
            return self.spda(
                q,
                k,
                v,
                cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
            )
        return self.varlen_attn(q,
                                k,
                                v,
                                cu_seqlens_k,
                                cu_seqlens_q=cu_seqlens_q,
                                sm_scale=sm_scale)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Helper functions for 3D sparse pattern
# These functions are not optimized and are very inefficient.
# Avoid calling them too frequently, or use a caching mechanism.
from functools import lru_cache
import numpy as np
import torch
from vllm.triton_utils import triton
class csr_matrix:
    """Simple implementation of CSR matrix conversion without scipy.
    This replaced scipy.sparse.csr_matrix() previously used.

    Attributes set by the constructor:
        shape: shape of the input array.
        data: the truthy entries, in row-major order.
        indices: column index of every entry in ``data`` (int64).
        indptr: ``indptr[i]:indptr[i + 1]`` is the slice of
            ``data`` / ``indices`` belonging to row ``i`` (int64).
    """

    def __init__(self, input_array):
        """Convert a 2-D NumPy array to CSR form.

        Raises:
            ValueError: if ``input_array`` is not a NumPy array, or is not
                2-D (the shape cannot be unpacked into rows and columns).
        """
        if not isinstance(input_array, np.ndarray):
            raise ValueError("Input must be a NumPy array")

        self.shape = input_array.shape
        rows, _ = self.shape

        # Vectorized equivalent of scanning every cell: np.nonzero and
        # boolean indexing both enumerate entries in row-major order.
        nonzero = input_array.astype(bool)
        per_row = nonzero.sum(axis=1)

        self.data = input_array[nonzero]
        # Explicit integer dtype so that an all-zero input still yields
        # integer index arrays (np.array([]) would default to float64).
        self.indices = np.nonzero(nonzero)[1].astype(np.int64)
        self.indptr = np.concatenate(
            (np.zeros(1, dtype=np.int64), np.cumsum(per_row))).astype(np.int64)
def dense_to_crow_col(x: torch.Tensor):
    """Convert a 2D/3D dense tensor into CSR (crow, col) index tensors.

    A 3D input is treated as a stack of 2D matrices along dim 0; each is
    converted separately and the results are stacked. Rows of column
    indices are right-padded with -1 to a common length.

    Returns:
        ``(crow_indices, col_indices)`` on the device of ``x``. For a 2D
        input both are 1-D; for a 3D input both are 2-D.
    """
    origin_device = x.device
    fill_value = -1
    is_single_matrix = x.dim() == 2
    assert x.dim() in (2, 3)

    stacked = x[None] if is_single_matrix else x
    per_head = [csr_matrix(head.bool().cpu().numpy()) for head in stacked]

    crows = torch.vstack([torch.from_numpy(m.indptr) for m in per_head])

    col_rows = [torch.from_numpy(m.indices) for m in per_head]
    widest = max(len(row) for row in col_rows)
    # Pad every head's column list up to the widest one.
    cols = torch.vstack([
        torch.cat([row, row.new_full((widest - row.shape[0], ), fill_value)])
        for row in col_rows
    ])

    if is_single_matrix:
        crows, cols = crows[0], cols[0]
    return crows.to(origin_device), cols.to(origin_device)
def crow_col_to_dense(crows: torch.Tensor,
                      cols: torch.Tensor,
                      dtype: torch.dtype = torch.float16):
    """Expand CSR (crow, col) index tensors into a dense 0/1 tensor.

    Accepts 1-D index tensors (single matrix) or 2-D ones (a stack of
    matrices along dim 0). The number of columns of the result is
    ``cols.max() + 1``. The result is returned on the device of ``crows``.
    """
    is_single_matrix = crows.dim() == 1
    if is_single_matrix:
        crows, cols = crows[None], cols[None]

    origin_device = crows.device
    # faster in cpu
    crows = crows.cpu()
    cols = cols.cpu()

    n_heads = crows.shape[0]
    n_rows = crows.shape[1] - 1
    n_cols = int(cols.max()) + 1
    dense = torch.zeros((n_heads, n_rows, n_cols), dtype=dtype)

    for head, (row_ptr, col_idx) in enumerate(zip(crows, cols)):
        bounds = row_ptr.tolist()
        # Consecutive pointer pairs delimit each row's column indices.
        for row, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:])):
            dense[head, row, col_idx[lo:hi]] = 1

    if is_single_matrix:
        dense = dense[0]
    return dense.to(origin_device)
def dense_to_ccol_row(x: torch.Tensor):
    """Convert a dense tensor to CSC-style (ccol, row) indices.

    Implemented as the CSR conversion of the tensor with its last two
    dimensions swapped.
    """
    return dense_to_crow_col(x.transpose(-2, -1))
def ccol_row_to_dense(ccol: torch.Tensor,
                      rows: torch.Tensor,
                      dtype: torch.dtype = torch.float16):
    """Rebuild a dense tensor from CSC column pointers and row indices.

    The inputs are decoded with ``crow_col_to_dense`` (which yields the
    transposed matrix) and the last two dimensions are then swapped back.

    Returns:
        A contiguous dense tensor; 2-D for 1-D inputs, 3-D for batched ones.
    """
    transposed = crow_col_to_dense(ccol, rows, dtype)
    # transpose(-2, -1) equals permute(0, 2, 1) for 3-D results and, unlike
    # the latter, also works on the 2-D result produced for 1-D inputs.
    return transposed.transpose(-2, -1).contiguous()
def _get_sparse_attn_mask_homo_head(
    q_len: int,
    max_seqlen: int,
    dtype: torch.dtype,
    device: torch.device,
    block_size: int = 128,
    local_blocks: int = 4,
    vert_stride: int = 4,
    return_dense: bool = False,
):
    """Build one block-sparse attention pattern (no per-head variation).

    A key block is kept for a query block when it is not ahead of the query
    block and either lies within ``local_blocks`` of it or its 1-based block
    index is a multiple of ``vert_stride``.

    :return: a tuple of 3:
        - tuple of crow_indices, col_indices representation
            of CSR format, for the last ``cdiv(q_len, block_size)`` query
            blocks.
        - block dense mask
        - all token dense mask (be aware that it can be
            OOM if it is too big) if `return_dense==True`,
            otherwise, None
    """
    with torch.no_grad():
        n_blocks = triton.cdiv(max_seqlen, block_size)
        block_ids = torch.arange(n_blocks)
        q_pos = block_ids[:, None]
        k_pos = block_ids[None]

        not_future = q_pos >= k_pos
        is_local = q_pos - k_pos < local_blocks
        is_strided = (block_ids + 1) % vert_stride == 0
        block_mask_dense = (not_future
                            & (is_local | is_strided)).to(device).to(dtype)

        n_q_blocks = triton.cdiv(q_len, block_size)
        crow_col = dense_to_crow_col(
            block_mask_dense[-n_q_blocks:].contiguous())

    mask_dense = None
    if return_dense:
        # Expand every block entry into a block_size x block_size tile, then
        # apply token-level causality to the last q_len query rows.
        tile = block_mask_dense.new_ones((block_size, block_size))
        mask_dense = torch.kron(block_mask_dense, tile)
        causal_mask = torch.tril(torch.ones(
            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
        mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
    return crow_col, block_mask_dense, mask_dense
def binary_mask_to_bias(mask_dense: torch.Tensor):
    """Turn a 0/1 mask into an additive bias.

    Computes ``1 - mask_dense`` into a new tensor and overwrites every
    non-zero entry of it with ``-inf``; the input tensor is left untouched.
    """
    inverted = 1 - mask_dense
    skip = inverted.bool()
    inverted.masked_fill_(skip, -torch.inf)
    return inverted
def get_head_sliding_step(n_heads: int,
                          vert_stride: int,
                          homo_head: bool = False):
    """Return the per-head offset step for the vertical-stride pattern.

    The step is 0 when ``homo_head`` is set; otherwise it is
    ``int(vert_stride / n_heads)``, but never less than 1.
    """
    if not homo_head:
        step = int(vert_stride / n_heads)
        return max(1, step)
    return 0
@lru_cache
def get_sparse_attn_mask(
    n_heads: int,
    q_len: int,
    max_seqlen: int,
    dtype: torch.dtype,
    device: torch.device,
    block_size: int = 64,
    local_blocks: int = 4,
    vert_stride: int = 4,
    homo_head: bool = True,
    return_dense: bool = False,
    dense_mask_type: str = "binary",
):
    """Build the block-sparse attention mask for ``n_heads`` heads.

    With ``homo_head`` the single pattern from
    ``_get_sparse_attn_mask_homo_head`` is expanded across heads; otherwise
    each head's vertical-stride pattern is offset by
    ``h * get_head_sliding_step(n_heads, vert_stride)``. Results are memoized
    with ``lru_cache``, keyed on the arguments.

    :param dense_mask_type: "binary" (0 for skip token, 1 for others)
        or "bias" (-inf for skip token, 0 for others)
    :return: a tuple of 3:
        - tuple of crow_indices, col_indices representation
            of CSR format.
        - block dense mask
        - all token dense mask (be aware that it can be OOM if it
            is too big) if `return_dense==True`, otherwise, None
    """
    assert dense_mask_type in ("binary", "bias")
    as_bias = dense_mask_type == "bias"

    if homo_head:
        with torch.no_grad():
            shared = _get_sparse_attn_mask_homo_head(
                q_len,
                max_seqlen,
                dtype,
                device,
                block_size,
                local_blocks,
                vert_stride,
                return_dense,
            )
            (crow, col), block_mask_dense, mask_dense = shared
            # Same pattern for every head: expand along a new head dimension.
            crow = crow[None].expand(n_heads, crow.shape[0])
            col = col[None].expand(n_heads, col.shape[0])
            if return_dense:
                mask_dense = mask_dense[None].expand(n_heads,
                                                     *mask_dense.shape)
                if as_bias:
                    mask_dense = binary_mask_to_bias(mask_dense)
            return (crow, col), block_mask_dense, mask_dense

    with torch.no_grad():
        n_blocks = triton.cdiv(max_seqlen, block_size)
        block_ids = torch.arange(n_blocks)
        q_pos = block_ids[None, :, None]
        k_pos = block_ids[None, None]
        head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
        per_head_strided = [
            (block_ids + h * head_sliding_step + 1) % vert_stride == 0
            for h in range(n_heads)
        ]
        mask_vert_strided = torch.vstack(per_head_strided).unsqueeze(1)
        not_future = q_pos >= k_pos
        is_local = q_pos - k_pos < local_blocks
        block_mask_dense = (not_future
                            & (is_local
                               | mask_vert_strided)).to(device).to(dtype)
        n_q_blocks = triton.cdiv(q_len, block_size)
        block_mask_dense_output = block_mask_dense[:, -n_q_blocks:]

    mask_dense = None
    if return_dense:
        tile = block_mask_dense.new_ones((block_size, block_size))
        mask_dense = torch.kron(block_mask_dense, tile)
        causal_mask = torch.tril(torch.ones(
            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
        mask_dense = (mask_dense[..., -q_len:, :max_seqlen] *
                      causal_mask[None])
        if as_bias:
            mask_dense = binary_mask_to_bias(mask_dense)
    return (
        dense_to_crow_col(block_mask_dense_output),
        block_mask_dense,
        mask_dense,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
from vllm_hpu_extension import cache_ops, ops
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
@dataclass
class HPUPagedAttentionMetadata:
    """Metadata for PagedAttention.

    Plain dataclass holding six optional ``torch.Tensor`` fields; any of
    them may be ``None``.
    """
    # NOTE(review): shapes, dtypes and exact meaning of these tensors are not
    # visible here; presumably they describe the KV-cache block layout for the
    # HPU paged-attention ops. Confirm against the code that fills them.
    block_list: Optional[torch.Tensor]
    block_mapping: Optional[torch.Tensor]
    block_usage: Optional[torch.Tensor]
    block_indices: Optional[torch.Tensor]
    block_offsets: Optional[torch.Tensor]
    block_groups: Optional[torch.Tensor]
class HPUPagedAttention:
    """Static paged-attention helpers backed by ``vllm_hpu_extension``.

    Cache manipulation is forwarded to ``cache_ops`` and decoding to
    ``ops.flat_pa``.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of supported head sizes."""
        supported = [64, 80, 96, 112, 128, 256]
        return supported

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return ``(num_blocks, block_size, num_kv_heads, head_size)``."""
        shape = (num_blocks, block_size, num_kv_heads, head_size)
        return shape

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(kv_cache[0], kv_cache[1])`` as key and value caches.

        ``num_kv_heads`` and ``head_size`` are accepted but not used.
        """
        return kv_cache[0], kv_cache[1]

    @staticmethod
    def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
                             key_cache: torch.Tensor,
                             value_cache: torch.Tensor,
                             slot_mapping: torch.Tensor, kv_cache_dtype: str,
                             is_prompt: bool) -> None:
        """Forward all arguments to ``cache_ops.reshape_and_cache``."""
        cache_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            is_prompt,
        )

    @staticmethod
    def forward_decode(**kwargs) -> torch.Tensor:
        """Return ``ops.flat_pa(**kwargs)``."""
        result = ops.flat_pa(**kwargs)
        return result

    @staticmethod
    def swap_blocks(
        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        """Call ``cache_ops.swap_blocks`` for the key cache, then the value
        cache, using the same ``src_to_dsts`` for both."""
        # Index 0 is the key cache, index 1 the value cache.
        for part in (0, 1):
            cache_ops.swap_blocks(src_kv_cache[part], dst_kv_cache[part],
                                  src_to_dsts)

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dsts: torch.Tensor,
    ) -> None:
        """Split every entry into key/value caches and hand both lists to
        ``cache_ops.copy_blocks``."""
        all_keys = [pair[0] for pair in kv_caches]
        all_values = [pair[1] for pair in kv_caches]
        cache_ops.copy_blocks(all_keys, all_values, src_to_dsts)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional, Tuple
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
_use_ipex = True
# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813
except (ImportError, AttributeError):
_use_ipex = False
import torch
from vllm import _custom_ops as ops
class _PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 80, 96, 112, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[int, ...]:
return 2, num_blocks, block_size * num_kv_heads * head_size
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
tp_rank: int = 0
blocksparse_local_blocks: int = 0
blocksparse_vert_stride: int = 0
blocksparse_block_size: int = 64
blocksparse_head_sliding_step: int = 0
block_size = value_cache.shape[3]
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
class _IPEXPagedAttention(_PagedAttention):
    """``_PagedAttention`` variant that routes cache writes and decoding
    through ``ipex_modules.PagedAttention``."""

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """View ``kv_cache[0]`` / ``kv_cache[1]`` as
        ``(num_blocks, num_kv_heads, -1, head_size)`` key / value caches."""
        num_blocks = kv_cache.shape[1]
        view_shape = (num_blocks, num_kv_heads, -1, head_size)
        key_cache = kv_cache[0].view(*view_shape)
        value_cache = kv_cache[1].view(*view_shape)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Call IPEX ``reshape_and_cache`` with a flattened int32 slot map.

        ``kv_cache_dtype``, ``k_scale`` and ``v_scale`` are accepted but not
        forwarded.
        """
        slots = slot_mapping.flatten().int()
        ipex_modules.PagedAttention.reshape_and_cache(key, value, key_cache,
                                                      value_cache, slots)

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Run IPEX ``single_query_cached_kv_attention``; nothing is returned.

        The block size is read from ``value_cache.shape[2]``.
        ``kv_cache_dtype``, ``k_scale`` and ``v_scale`` are accepted but not
        forwarded.
        """
        block_size = value_cache.shape[2]
        # Each KV head index is repeated query.size(1) // num_kv_heads times.
        # NOTE(review): presumably query.size(1) is the number of query
        # heads - confirm against the caller.
        repeats = query.size(1) // num_kv_heads
        kv_head_ids = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        ).view(num_kv_heads, 1)
        head_mapping = kv_head_ids.repeat_interleave(repeats).flatten()
        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output,
            query.contiguous(),
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention
...@@ -6,7 +6,7 @@ from typing import Optional ...@@ -6,7 +6,7 @@ from typing import Optional
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
def get_aiter_mla_metadata(max_batch_size: int, block_size: int, def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
...@@ -93,8 +93,12 @@ def mla_decode_fwd_fake( ...@@ -93,8 +93,12 @@ def mla_decode_fwd_fake(
if current_platform.is_rocm(): if current_platform.is_rocm():
if is_torch_equal_or_newer("2.7.0"):
    # No tags are passed for torch >= 2.7.0.
    tags = ()
else:
    # Must be a flat tuple of tags: a trailing comma after the closing
    # parenthesis would wrap it in a second tuple, i.e. ((Tag, ), ).
    tags = (torch.Tag.needs_fixed_stride_order, )
direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
op_func=mla_decode_fwd_impl, op_func=mla_decode_fwd_impl,
mutates_args=["o"], mutates_args=["o"],
fake_impl=mla_decode_fwd_fake, fake_impl=mla_decode_fwd_fake,
tags=[torch.Tag.needs_fixed_stride_order]) tags=tags)
...@@ -8,10 +8,9 @@ ...@@ -8,10 +8,9 @@
# - Thomas Parnell <tpa@zurich.ibm.com> # - Thomas Parnell <tpa@zurich.ibm.com>
import torch import torch
import triton
import triton.language as tl
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -145,7 +144,19 @@ def kernel_unified_attention_2d( ...@@ -145,7 +144,19 @@ def kernel_unified_attention_2d(
mask=query_mask_1, mask=query_mask_1,
other=0.0) other=0.0)
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) # compute the length of the longest sequence prefix spanned by any
# query token in the current q_block (q_block_local_idx)
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
BLOCK_M - 1) // num_queries_per_kv + 1
# adjust for potential padding in the last q_block by considering the
# actual sequence length
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
# calculate the number of tiles (blocks) that need to be processed to
# cover the longest sequence prefix (due to causal masking, blocks beyond
# this prefix can be skipped)
num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
# iterate through tiles # iterate through tiles
for j in range(0, num_blocks): for j in range(0, num_blocks):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache from functools import cache
from typing import Generator, Optional, Union from typing import Generator, Optional, Union
...@@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: ...@@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend return forced_attn_backend
def supports_head_size( @dataclass(frozen=True)
class _IsSupported:
can_import: bool
head_size: bool
dtype: bool
def __bool__(self) -> bool:
return self.can_import and self.head_size and self.dtype
def is_attn_backend_supported(
attn_backend: Union[str, type[AttentionBackend]], attn_backend: Union[str, type[AttentionBackend]],
head_size: int, head_size: int,
) -> bool: dtype: torch.dtype,
*,
allow_import_error: bool = True,
) -> _IsSupported:
if isinstance(attn_backend, str): if isinstance(attn_backend, str):
try: try:
attn_backend = resolve_obj_by_qualname(attn_backend) attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError: except ImportError:
return False if not allow_import_error:
raise
return _IsSupported(can_import=False, head_size=False, dtype=False)
assert isinstance(attn_backend, type) assert isinstance(attn_backend, type)
# TODO: Update the interface once V0 is removed # TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(attn_backend, if get_supported_head_sizes := getattr(attn_backend,
"get_supported_head_sizes", None): "get_supported_head_sizes", None):
return head_size in get_supported_head_sizes() is_head_size_supported = head_size in get_supported_head_sizes()
if validate_head_size := getattr(attn_backend, "validate_head_size", None): elif validate_head_size := getattr(attn_backend, "validate_head_size",
None):
try: try:
validate_head_size(head_size) validate_head_size(head_size)
return True is_head_size_supported = True
except Exception: except Exception:
return False is_head_size_supported = False
else:
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"head size validation")
raise NotImplementedError(f"{attn_backend.__name__} does not support " if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
"head size validation") None):
is_dtype_supported = dtype in get_supported_dtypes()
else:
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"dtype validation")
return _IsSupported(
can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
)
def get_attn_backend( def get_attn_backend(
...@@ -112,7 +143,6 @@ def get_attn_backend( ...@@ -112,7 +143,6 @@ def get_attn_backend(
kv_cache_dtype: Optional[str], kv_cache_dtype: Optional[str],
block_size: int, block_size: int,
is_attention_free: bool, is_attention_free: bool,
is_blocksparse: bool = False,
use_mla: bool = False, use_mla: bool = False,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
...@@ -126,7 +156,6 @@ def get_attn_backend( ...@@ -126,7 +156,6 @@ def get_attn_backend(
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
block_size=block_size, block_size=block_size,
is_attention_free=is_attention_free, is_attention_free=is_attention_free,
is_blocksparse=is_blocksparse,
use_v1=envs.VLLM_USE_V1, use_v1=envs.VLLM_USE_V1,
use_mla=use_mla, use_mla=use_mla,
) )
...@@ -139,16 +168,9 @@ def _cached_get_attn_backend( ...@@ -139,16 +168,9 @@ def _cached_get_attn_backend(
kv_cache_dtype: Optional[str], kv_cache_dtype: Optional[str],
block_size: int, block_size: int,
is_attention_free: bool, is_attention_free: bool,
is_blocksparse: bool = False,
use_v1: bool = False, use_v1: bool = False,
use_mla: bool = False, use_mla: bool = False,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
# If there are no attention layers (e.g. we are running Mamba), # If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION # use the placeholder NO_ATTENTION
if is_attention_free: if is_attention_free:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    """Check that ``target_layer_name`` is a usable KV sharing target.

    Returns ``None`` when the target is acceptable.

    Raises:
        ValueError: if the target is the current layer itself, is missing
            from ``static_forward_context``, or has a different ``attn_type``
            than the current layer.
    """
    prefix = (f"Specified KV sharing target layer for {current_layer_name} "
              f"is not valid: target layer {target_layer_name} ")

    if target_layer_name == current_layer_name:
        raise ValueError(prefix +
                         "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If target layer name is not in the static fwd context, it means either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the model
        current_idx = extract_layer_index(current_layer_name)
        target_idx = extract_layer_index(target_layer_name)
        if current_idx <= target_idx:
            raise ValueError(prefix + "must come before the current layer.")
        raise ValueError(prefix +
                         "is not a valid Attention layer in the model.")

    # Currently KV sharing is only supported between layers of the same type
    target_type = static_forward_context[target_layer_name].attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if target_type != expected:
        raise ValueError(
            prefix +
            f"must be the same type as the current layer ({expected}).")
...@@ -481,6 +481,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -481,6 +481,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
) )
parser.add_argument(
"--no-stream",
action="store_true",
help="Do not load the dataset in streaming mode.",
)
parser.add_argument( parser.add_argument(
"--dataset-path", "--dataset-path",
type=str, type=str,
...@@ -649,6 +654,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -649,6 +654,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
dataset_class = ASRDataset dataset_class = ASRDataset
args.hf_split = "train" args.hf_split = "train"
elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS:
dataset_class = MLPerfDataset
args.hf_split = "train"
else: else:
supported_datasets = set([ supported_datasets = set([
dataset_name for cls in HuggingFaceDataset.__subclasses__() dataset_name for cls in HuggingFaceDataset.__subclasses__()
...@@ -674,6 +682,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ...@@ -674,6 +682,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
dataset_subset=args.hf_subset, dataset_subset=args.hf_subset,
dataset_split=args.hf_split, dataset_split=args.hf_split,
random_seed=args.seed, random_seed=args.seed,
no_stream=args.no_stream,
).sample( ).sample(
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -971,6 +980,7 @@ class HuggingFaceDataset(BenchmarkDataset): ...@@ -971,6 +980,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self, self,
dataset_path: str, dataset_path: str,
dataset_split: str, dataset_split: str,
no_stream: bool = False,
dataset_subset: Optional[str] = None, dataset_subset: Optional[str] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
...@@ -978,6 +988,7 @@ class HuggingFaceDataset(BenchmarkDataset): ...@@ -978,6 +988,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self.dataset_split = dataset_split self.dataset_split = dataset_split
self.dataset_subset = dataset_subset self.dataset_subset = dataset_subset
self.load_stream = not no_stream
self.load_data() self.load_data()
def load_data(self) -> None: def load_data(self) -> None:
...@@ -986,7 +997,7 @@ class HuggingFaceDataset(BenchmarkDataset): ...@@ -986,7 +997,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self.dataset_path, self.dataset_path,
name=self.dataset_subset, name=self.dataset_subset,
split=self.dataset_split, split=self.dataset_split,
streaming=True, streaming=self.load_stream,
) )
self.data = self.data.shuffle(seed=self.random_seed) self.data = self.data.shuffle(seed=self.random_seed)
...@@ -1439,3 +1450,82 @@ class ASRDataset(HuggingFaceDataset): ...@@ -1439,3 +1450,82 @@ class ASRDataset(HuggingFaceDataset):
) )
self.maybe_oversample_requests(sampled_requests, num_requests) self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests return sampled_requests
# -----------------------------------------------------------------------------
# MLPerf Dataset Implementation
# -----------------------------------------------------------------------------
class MLPerfDataset(HuggingFaceDataset):
    """
    MLPerf Inference Dataset.

    Dataset on HF:
    https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data
    https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data

    Records are read through three keys:
      - "system_prompt": system role instruction.
      - "question": user question.
      - "output": reference answer.

    The system prompt and question are rendered into one prompt with the
    tokenizer's chat template. Unless the caller passes ``output_len``, the
    expected output length is the token count of the reference answer.
    """

    SUPPORTED_DATASET_PATHS = {
        "mgoin/mlperf-inference-llama2-data",
        "mgoin/mlperf-inference-llama3.1-data",
    }

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        **kwargs,
    ) -> list[SampleRequest]:
        """Collect up to ``num_requests`` requests from ``self.data``.

        Records rejected by ``is_valid_sequence`` are skipped; the result is
        then passed through ``maybe_oversample_requests`` before returning.
        """
        # Without an explicit output_len, fall back to the reference length.
        use_reference_len = output_len is None
        collected: list[SampleRequest] = []

        for record in self.data:
            if len(collected) >= num_requests:
                break

            system_prompt = record["system_prompt"]
            question = record["question"]
            reference_answer = record["output"]

            # Render the system + user turns with the chat template.
            chat = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ]
            prompt_text = tokenizer.apply_chat_template(
                chat, add_generation_prompt=True, tokenize=False
            )
            prompt_len = len(tokenizer(prompt_text).input_ids)

            # The reference answer is tokenized without special tokens.
            reference_ids = tokenizer(
                reference_answer, add_special_tokens=False
            ).input_ids
            reference_len = len(reference_ids)
            if use_reference_len:
                expected_output_len = reference_len
            else:
                expected_output_len = output_len

            if is_valid_sequence(prompt_len, expected_output_len):
                collected.append(
                    SampleRequest(
                        prompt=prompt_text,
                        prompt_len=prompt_len,
                        expected_output_len=expected_output_len,
                    )
                )

        self.maybe_oversample_requests(collected, num_requests)
        return collected
...@@ -138,31 +138,54 @@ async def get_request( ...@@ -138,31 +138,54 @@ async def get_request(
input_requests = list(input_requests) input_requests = list(input_requests)
total_requests = len(input_requests) total_requests = len(input_requests)
request_index = 0 assert total_requests > 0, "No requests provided."
for request in input_requests: # Precompute delays among requests to minimize request send laggings
request_rates = []
delay_ts = []
for request_index, request in enumerate(input_requests):
current_request_rate = _get_current_request_rate(ramp_up_strategy, current_request_rate = _get_current_request_rate(ramp_up_strategy,
ramp_up_start_rps, ramp_up_start_rps,
ramp_up_end_rps, ramp_up_end_rps,
request_index, request_index,
total_requests, total_requests,
request_rate) request_rate)
request_rates.append(current_request_rate)
yield request, current_request_rate
request_index += 1
if current_request_rate == float("inf"): if current_request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait. delay_ts.append(0)
continue else:
theta = 1.0 / (current_request_rate * burstiness)
theta = 1.0 / (current_request_rate * burstiness)
# Sample the request interval from the gamma distribution.
# Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution.
# If burstiness is 1, it follows exponential distribution. delay_ts.append(np.random.gamma(shape=burstiness, scale=theta))
interval = np.random.gamma(shape=burstiness, scale=theta)
# The next request will be sent after the interval. # Calculate the cumulative delay time from the first sent out requests.
await asyncio.sleep(interval) for i in range(1, len(delay_ts)):
delay_ts[i] += delay_ts[i - 1]
if ramp_up_strategy is None and delay_ts[-1] != 0:
# When ramp_up_strategy is not set, we assume the request rate is fixed
# and all requests should be sent in target_total_delay_s, the following
# logic would re-scale delay time to ensure the final delay_ts
# align with target_total_delay_s.
#
# NOTE: If we simply accumulate the random delta values
# from the gamma distribution, their sum would have 1-2% gap
# from target_total_delay_s. The purpose of the following logic is to
# close the gap for stabilizing the throughput data
# from different random seeds.
target_total_delay_s = total_requests / request_rate
normalize_factor = target_total_delay_s / delay_ts[-1]
delay_ts = [delay * normalize_factor for delay in delay_ts]
start_ts = time.time()
for request_index, request in enumerate(input_requests):
if delay_ts[request_index] > 0:
current_ts = time.time()
sleep_interval_s = start_ts + delay_ts[request_index] - current_ts
if sleep_interval_s > 0:
await asyncio.sleep(sleep_interval_s)
yield request, request_rates[request_index]
def calculate_metrics( def calculate_metrics(
......
...@@ -96,25 +96,30 @@ DEFAULT_PIP_PATTERNS = { ...@@ -96,25 +96,30 @@ DEFAULT_PIP_PATTERNS = {
def run(command): def run(command):
"""Return (return-code, stdout, stderr).""" """Return (return-code, stdout, stderr)."""
shell = True if type(command) is str else False shell = True if type(command) is str else False
p = subprocess.Popen(command, try:
stdout=subprocess.PIPE, p = subprocess.Popen(command,
stderr=subprocess.PIPE, stdout=subprocess.PIPE,
shell=shell) stderr=subprocess.PIPE,
raw_output, raw_err = p.communicate() shell=shell)
rc = p.returncode raw_output, raw_err = p.communicate()
if get_platform() == 'win32': rc = p.returncode
enc = 'oem' if get_platform() == 'win32':
else: enc = 'oem'
enc = locale.getpreferredencoding() else:
output = raw_output.decode(enc) enc = locale.getpreferredencoding()
if command == 'nvidia-smi topo -m': output = raw_output.decode(enc)
# don't remove the leading whitespace of `nvidia-smi topo -m` if command == 'nvidia-smi topo -m':
# because they are meaningful # don't remove the leading whitespace of `nvidia-smi topo -m`
output = output.rstrip() # because they are meaningful
else: output = output.rstrip()
output = output.strip() else:
err = raw_err.decode(enc) output = output.strip()
return rc, output, err.strip() err = raw_err.decode(enc)
return rc, output, err.strip()
except FileNotFoundError:
cmd_str = command if isinstance(command, str) else command[0]
return 127, '', f"Command not found: {cmd_str}"
def run_and_read_all(run_lambda, command): def run_and_read_all(run_lambda, command):
...@@ -148,7 +153,7 @@ def get_conda_packages(run_lambda, patterns=None): ...@@ -148,7 +153,7 @@ def get_conda_packages(run_lambda, patterns=None):
if patterns is None: if patterns is None:
patterns = DEFAULT_CONDA_PATTERNS patterns = DEFAULT_CONDA_PATTERNS
conda = os.environ.get('CONDA_EXE', 'conda') conda = os.environ.get('CONDA_EXE', 'conda')
out = run_and_read_all(run_lambda, "{} list".format(conda)) out = run_and_read_all(run_lambda, [conda, 'list'])
if out is None: if out is None:
return out return out
......
...@@ -120,10 +120,15 @@ class CompilerManager: ...@@ -120,10 +120,15 @@ class CompilerManager:
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
compiled_graph = self.compiler.load(handle, graph, example_inputs, compiled_graph = self.compiler.load(handle, graph, example_inputs,
graph_index, runtime_shape) graph_index, runtime_shape)
logger.debug( if runtime_shape is None:
"Directly load the %s-th graph for shape %s from %s via " logger.debug(
"handle %s", graph_index, str(runtime_shape), self.compiler.name, "Directly load the %s-th graph for dynamic shape from %s via "
handle) "handle %s", graph_index, self.compiler.name, handle)
else:
logger.debug(
"Directly load the %s-th graph for shape %s from %s via "
"handle %s", graph_index, str(runtime_shape),
self.compiler.name, handle)
return compiled_graph return compiled_graph
def compile(self, def compile(self,
...@@ -152,9 +157,15 @@ class CompilerManager: ...@@ -152,9 +157,15 @@ class CompilerManager:
# there can be multiple graphs due to piecewise compilation. # there can be multiple graphs due to piecewise compilation.
now = time.time() now = time.time()
elapsed = now - compilation_start_time elapsed = now - compilation_start_time
logger.info( if runtime_shape is None:
"Directly load the compiled graph(s) for shape %s " logger.info(
"from the cache, took %.3f s", str(runtime_shape), elapsed) "Directly load the compiled graph(s) for dynamic shape "
"from the cache, took %.3f s", elapsed)
else:
logger.info(
"Directly load the compiled graph(s) for shape %s "
"from the cache, took %.3f s", str(runtime_shape),
elapsed)
return compiled_graph return compiled_graph
# no compiler cached the graph, or the cache is disabled, # no compiler cached the graph, or the cache is disabled,
...@@ -172,17 +183,28 @@ class CompilerManager: ...@@ -172,17 +183,28 @@ class CompilerManager:
assert compiled_graph is not None, "Failed to compile the graph" assert compiled_graph is not None, "Failed to compile the graph"
# store the artifact in the cache # store the artifact in the cache
if handle is not None: if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
self.cache[(runtime_shape, graph_index, self.cache[(runtime_shape, graph_index,
self.compiler.name)] = handle self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True self.is_cache_updated = True
if graph_index == 0: if graph_index == 0:
# adds some info logging for the first graph # adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use", if runtime_shape is None:
str(runtime_shape)) logger.info(
logger.debug( "Cache the graph for dynamic shape for later use")
"store the %s-th graph for shape %s from %s via handle %s", else:
graph_index, str(runtime_shape), self.compiler.name, handle) logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
if runtime_shape is None:
logger.debug(
"Store the %s-th graph for dynamic shape from %s via "
"handle %s", graph_index, self.compiler.name, handle)
else:
logger.debug(
"Store the %s-th graph for shape %s from %s via handle %s",
graph_index, str(runtime_shape), self.compiler.name,
handle)
# after compiling the last graph, record the end time # after compiling the last graph, record the end time
if graph_index == num_graphs - 1: if graph_index == num_graphs - 1:
...@@ -190,7 +212,7 @@ class CompilerManager: ...@@ -190,7 +212,7 @@ class CompilerManager:
elapsed = now - compilation_start_time elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed compilation_config.compilation_time += elapsed
if runtime_shape is None: if runtime_shape is None:
logger.info("Compiling a graph for general shape takes %.2f s", logger.info("Compiling a graph for dynamic shape takes %.2f s",
elapsed) elapsed)
else: else:
logger.info("Compiling a graph for shape %s takes %.2f s", logger.info("Compiling a graph for shape %s takes %.2f s",
...@@ -308,7 +330,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -308,7 +330,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
i for i, x in enumerate(args) if isinstance(x, torch.SymInt) i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
] ]
global compilation_start_time global compilation_start_time
compiled_graph_for_general_shape = self.vllm_backend.\ compiled_graph_for_dynamic_shape = self.vllm_backend.\
compiler_manager.compile( compiler_manager.compile(
submod, submod,
args, args,
...@@ -323,7 +345,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -323,7 +345,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.module.__dict__[target] = piecewise_backend( self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index, submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices, len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape, self.vllm_backend) compiled_graph_for_dynamic_shape, self.vllm_backend)
compilation_counter.num_piecewise_capturable_graphs_seen += 1 compilation_counter.num_piecewise_capturable_graphs_seen += 1
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec
from typing import Optional from typing import Optional
import torch import torch
import torch._inductor.pattern_matcher as pm import torch._inductor.pattern_matcher as pm
import torch.fx as fx import torch.fx as fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tp_group from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import direct_register_custom_op
from .vllm_inductor_pass import VllmInductorPass from .vllm_inductor_pass import VllmInductorPass
if find_spec("flashinfer"):
try:
import flashinfer.comm as flashinfer_comm
flashinfer_comm = (flashinfer_comm if hasattr(
flashinfer_comm, "trtllm_allreduce_fusion") else None)
except ImportError:
flashinfer_comm = None
else:
flashinfer_comm = None
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
# Resolved op overloads used by this module. RMS_OP and RMS_ADD_OP are the
# targets wrapped with auto_functionalized() in the search patterns of
# AllReduceRMSNORMPattern and AllReduceFusedAddRMSNormPattern respectively.
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class BasePattern: class BasePattern:
...@@ -43,7 +61,8 @@ class GEMMReduceScatterPattern(BasePattern): ...@@ -43,7 +61,8 @@ class GEMMReduceScatterPattern(BasePattern):
mm, mm,
dim=0, dim=0,
world_size=self.tp_size, world_size=self.tp_size,
group_name=self.tp.unique_name) group_name=self.tp.unique_name,
)
return reduce_scatter return reduce_scatter
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor): def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
...@@ -79,7 +98,8 @@ class AllGatherGEMMPattern(BasePattern): ...@@ -79,7 +98,8 @@ class AllGatherGEMMPattern(BasePattern):
x, x,
dim=0, dim=0,
world_size=self.tp_size, world_size=self.tp_size,
group_name=self.tp.unique_name) group_name=self.tp.unique_name,
)
return torch.ops.aten.mm.default(all_gather, weight) return torch.ops.aten.mm.default(all_gather, weight)
...@@ -125,3 +145,343 @@ class AsyncTPPass(VllmInductorPass): ...@@ -125,3 +145,343 @@ class AsyncTPPass(VllmInductorPass):
logger.debug("Replaced %s patterns", count) logger.debug("Replaced %s patterns", count)
self.dump_graph(graph, "after_async_tp_pass") self.dump_graph(graph, "after_async_tp_pass")
self.end_and_log() self.end_and_log()
if flashinfer_comm is not None:
# Workspace passed to the flashinfer fused all-reduce kernel. It stays None
# until AllReduceFusionPass.__init__ creates the IPC workspace and stores it
# here.
_FI_WORKSPACE_TENSOR = None

MiB = 1 << 20

# Conservative limit applied when the world size has no entry in
# _FI_MAX_SIZES.
_DEFAULT_FI_MAX_SIZE = MiB // 2

# Largest input tensor size (in bytes), keyed by world size, for which the
# flashinfer fused all-reduce is used.
_FI_MAX_SIZES = {
    2: MiB,  # 1MB
    4: MiB,  # 1MB
    6: MiB // 2,  # 512KB
    8: MiB // 2,  # 512KB
}
def call_trtllm_fused_allreduce_norm(
    allreduce_in: torch.Tensor,
    residual: torch.Tensor,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    world_rank: int,
    world_size: int,
    launch_with_pdl: bool,
    trigger_completion_at_end: bool,
    fp32_acc: bool,
    max_token_num: int,
    norm_out: Optional[torch.Tensor] = None,
) -> None:
    """All-reduce ``allreduce_in`` and apply RMSNorm, writing results in place.

    Small inputs are handled by flashinfer's fused one-shot kernel; anything
    larger falls back to a regular tensor-parallel all-reduce followed by the
    ``torch.ops._C`` norm kernels.

    Args:
        allreduce_in: Tensor to reduce; its shape is unpacked as
            ``(num_tokens, hidden_size)``. Always overwritten: in the
            fallback path it receives a copy of the reduced tensor after the
            norm kernel ran, in the flashinfer path it is passed as
            ``norm_out`` (when none is given) or as ``residual_out``.
        residual: Residual tensor handed to the fused kernel or to
            ``fused_add_rms_norm``.
        rms_gamma: RMSNorm weight.
        rms_eps: RMSNorm epsilon.
        world_rank: Rank forwarded to flashinfer.
        world_size: World size; also selects the size limit.
        launch_with_pdl: Forwarded to flashinfer.
        trigger_completion_at_end: Forwarded to flashinfer.
        fp32_acc: Forwarded to flashinfer.
        max_token_num: Token count bounding the fused path.
        norm_out: Optional separate output for the normalized result.
    """
    token_count, hidden_dim = allreduce_in.shape
    bytes_per_token = hidden_dim * allreduce_in.element_size()
    input_bytes = token_count * bytes_per_token
    fused_limit = min(
        _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
        max_token_num * bytes_per_token,
    )

    if input_bytes > fused_limit:
        # Too large for the fused kernel: reduce first, then normalize.
        reduced = tensor_model_parallel_all_reduce(allreduce_in)
        if norm_out is None:
            torch.ops._C.fused_add_rms_norm(reduced, residual, rms_gamma,
                                            rms_eps)
        else:
            torch.ops._C.rms_norm(norm_out, reduced, rms_gamma, rms_eps)
        allreduce_in.copy_(reduced)
        return

    assert (_FI_WORKSPACE_TENSOR is not None
            ), "Flashinfer must be enabled when using flashinfer"

    if norm_out is not None:
        # flashinfer cannot emit rms_norm and allreduce_out together, so the
        # all-reduce result is returned through residual_out; the caller is
        # expected to pass a zeroed residual in this case.
        residual_out = allreduce_in
    else:
        norm_out = allreduce_in
        residual_out = residual

    fusion_pattern = (
        flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm)
    # Inputs below the size limit only ever use the one-shot all-reduce.
    flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=allreduce_in,
        token_num=token_count,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        world_rank=world_rank,
        world_size=world_size,
        hidden_dim=hidden_dim,
        workspace_ptrs=_FI_WORKSPACE_TENSOR,
        launch_with_pdl=launch_with_pdl,
        use_oneshot=True,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=fusion_pattern,
        allreduce_out=None,
        quant_out=None,
        scale_out=None,
        layout_code=None,
        scale_factor=None,
    )
def call_trtllm_fused_allreduce_norm_fake(
    allreduce_in: torch.Tensor,
    residual: torch.Tensor,
    rms_gamma: torch.Tensor,
    rms_eps: float,
    world_rank: int,
    world_size: int,
    launch_with_pdl: bool,
    trigger_completion_at_end: bool,
    fp32_acc: bool,
    max_token_num: int,
    norm_out: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation of ``call_trtllm_fused_allreduce_norm``.

    Takes the same arguments as the real op, touches none of them and
    returns ``None``.
    """
    return None
# Register the fused all-reduce + RMSNorm function as a custom op.
# The three tensors listed in mutates_args are the ones the implementation
# may write to in place; the fake implementation is a no-op.
direct_register_custom_op(
    op_name="flashinfer_trtllm_fused_allreduce_norm",
    op_func=call_trtllm_fused_allreduce_norm,
    mutates_args=[
        "allreduce_in",
        "residual",
        "norm_out",
    ],
    fake_impl=call_trtllm_fused_allreduce_norm_fake,
    dispatch_key=current_platform.dispatch_key,
)
# Module-level alias for the default overload of the op registered above.
flashinfer_trtllm_fused_allreduce_norm = (
    torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default)
class FlashInferFusedAllReduceParams:
    """Bundle of settings passed to the FlashInfer fused all-reduce op.

    ``rank``, ``world_size``, ``use_fp32_lamport`` and ``max_token_num`` come
    from the caller; the remaining flags are fixed at construction time.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        use_fp32_lamport: bool = False,
        max_token_num: int = 1024,
    ):
        # Caller-supplied values.
        self.rank = rank
        self.world_size = world_size
        self.use_fp32_lamport = use_fp32_lamport
        # Fixed kernel flags.
        self.trigger_completion_at_end = self.launch_with_pdl = True
        self.fp32_acc = True
        self.use_oneshot = False
        self.max_token_num = max_token_num

    def get_trtllm_fused_allreduce_kwargs(self):
        """Return the keyword arguments forwarded to the fused custom op."""
        return dict(
            world_rank=self.rank,
            world_size=self.world_size,
            launch_with_pdl=self.launch_with_pdl,
            trigger_completion_at_end=self.trigger_completion_at_end,
            fp32_acc=self.fp32_acc,
            max_token_num=self.max_token_num,
        )
class AllReduceRMSNORMPattern(BasePattern):
    """Rewrite "all-reduce, then rms_norm" into the fused flashinfer op.

    The search pattern is ``tensor_model_parallel_all_reduce`` feeding a
    functionalized ``RMS_OP``; the replacement is one functionalized call to
    ``flashinfer_trtllm_fused_allreduce_norm`` with a zero residual.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self):
        """Return example ``[input, rms_result, weight]`` tensors."""
        tensor_opts = {"device": self.device, "dtype": self.dtype}
        return [
            torch.empty([1, 8, 4], **tensor_opts),
            torch.empty([1, 8, 4], **tensor_opts),
            torch.empty([4], **tensor_opts),
        ]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(input: torch.Tensor, rms_result: torch.Tensor,
                    weight: torch.Tensor):
            reduced = tensor_model_parallel_all_reduce(input)
            normed = auto_functionalized(
                RMS_OP,
                result=rms_result,
                input=reduced,
                weight=weight,
                epsilon=self.epsilon,
            )
            return normed[1], reduced

        def replacement(input: torch.Tensor, rms_result: torch.Tensor,
                        weight: torch.Tensor):
            # No residual exists in this pattern, so a zero tensor is passed.
            zero_residual = torch.zeros_like(input)
            fused = auto_functionalized(
                torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=rms_result,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # NOTE(review): presumably index 3 is the updated norm_out and
            # index 1 the updated allreduce_in (mutates_args order) - confirm
            # against auto_functionalized's output layout.
            norm_result = fused[3]
            allreduce_result = fused[1]
            return norm_result, allreduce_result

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
class AllReduceFusedAddRMSNormPattern(BasePattern):
    """Rewrite "all-reduce, then fused_add_rms_norm" into the fused op.

    The search pattern is ``tensor_model_parallel_all_reduce`` feeding a
    functionalized ``RMS_ADD_OP``; the replacement is one functionalized call
    to ``flashinfer_trtllm_fused_allreduce_norm`` without a separate
    ``norm_out``.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str,
        allreduce_params: FlashInferFusedAllReduceParams,
    ):
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self):
        """Return example ``[residual, input, weight]`` tensors."""
        tensor_opts = {"device": self.device, "dtype": self.dtype}
        hidden = torch.empty([4, 4], **tensor_opts)
        residual = torch.empty([4, 4], **tensor_opts)
        weight = torch.empty([4, 4], **tensor_opts)
        return [residual, hidden, weight]

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(residual: torch.Tensor, input: torch.Tensor,
                    weight: torch.Tensor):
            reduced = tensor_model_parallel_all_reduce(input)
            normed = auto_functionalized(
                RMS_ADD_OP,
                input=reduced,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
            )
            return normed[1], normed[2]

        def replacement(residual: torch.Tensor, input: torch.Tensor,
                        weight: torch.Tensor):
            fused = auto_functionalized(
                torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
                allreduce_in=input,
                residual=residual,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                norm_out=None,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # NOTE(review): presumably index 1 is the updated allreduce_in
            # and index 2 the updated residual (mutates_args order) - confirm
            # against auto_functionalized's output layout.
            first_out = fused[1]
            second_out = fused[2]
            return first_out, second_out

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)
class AllReduceFusionPass(VllmInductorPass):
    """Pass that replaces all-reduce + RMSNorm with a fused flashinfer op.

    The pass is only enabled when all of the following hold: the tensor
    parallel world size is greater than 1, a model config is present,
    flashinfer's comm module is available, and the world size has an entry
    in ``_FI_MAX_SIZES``. Otherwise ``__call__`` is a no-op.
    """

    def __init__(self, config: VllmConfig):
        """Create the flashinfer IPC workspace and register the patterns.

        Args:
            config: vLLM configuration; ``model_config`` supplies the hidden
                size and ``compilation_config.pass_config`` the maximum
                token count for the fused path.
        """
        global _FI_WORKSPACE_TENSOR

        super().__init__(config)
        # Flipped to False only after everything below succeeded.
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass")
        if config.model_config is None:
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        # NOTE(review): model_dtype / device are not assigned in this class;
        # presumably they are set by VllmInductorPass.__init__ - confirm.
        use_fp32_lamport = self.model_dtype == torch.float32
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass")
            return
        # Check if the world size is supported
        if self.tp_size not in _FI_MAX_SIZES:
            logger.warning(
                "Flashinfer allreduce fusion is not "
                "supported for world size %s",
                self.tp_size,
            )
            return

        max_token_num = (config.compilation_config.pass_config.
                         fi_allreduce_fusion_max_token_num)
        self.ipc_handles, workspace_tensor = (
            flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                tp_rank=rank,
                tp_size=self.tp_size,
                max_token_num=max_token_num,
                hidden_dim=self.hidden_dim,
                group=self.group,
                use_fp32_lamport=use_fp32_lamport,
            ))

        # Publish the workspace for call_trtllm_fused_allreduce_norm.
        _FI_WORKSPACE_TENSOR = workspace_tensor
        self.allreduce_params = FlashInferFusedAllReduceParams(
            rank=rank,
            world_size=self.tp_size,
            use_fp32_lamport=use_fp32_lamport,
            max_token_num=max_token_num,
        )

        # Patterns capture epsilon as a constant, so register one pair per
        # epsilon value.
        for epsilon in [1e-5, 1e-6]:
            AllReduceRMSNORMPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            AllReduceFusedAddRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)

        self.disabled = False

    def __call__(self, graph: fx.Graph):
        """Apply the registered replacements to ``graph`` unless disabled."""
        if self.disabled:
            return
        self.begin()
        self.dump_graph(graph, "before_all_reduce_fusion_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_all_reduce_fusion_pass")
        self.end_and_log()

    def __del__(self):
        """Release the flashinfer IPC workspace if this pass created one."""
        # __del__ also runs on instances whose __init__ raised before
        # ``disabled`` was assigned; treat those as disabled instead of
        # failing with AttributeError inside the finalizer.
        if getattr(self, "disabled", True):
            return
        if flashinfer_comm is not None:
            flashinfer_comm.trtllm_destroy_ipc_workspace(
                self.ipc_handles, self.group)
...@@ -213,7 +213,9 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -213,7 +213,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
# Save the compiled artifact to disk in the specified path # Save the compiled artifact to disk in the specified path
assert key is not None assert key is not None
path = os.path.join(self.cache_dir, key) path = os.path.join(self.cache_dir, key)
compiled_graph.save(path=path, format="unpacked") if not envs.VLLM_DISABLE_COMPILE_CACHE:
compiled_graph.save(path=path, format="unpacked")
compilation_counter.num_compiled_artifacts_saved += 1
return compiled_graph, (key, path) return compiled_graph, (key, path)
def load(self, def load(self,
...@@ -421,6 +423,12 @@ class InductorAdaptor(CompilerInterface): ...@@ -421,6 +423,12 @@ class InductorAdaptor(CompilerInterface):
if is_torch_equal_or_newer("2.6"): if is_torch_equal_or_newer("2.6"):
stack.enter_context( stack.enter_context(
torch._inductor.config.patch(fx_graph_remote_cache=False)) torch._inductor.config.patch(fx_graph_remote_cache=False))
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
stack.enter_context(
torch._functorch.config.patch(enable_autograd_cache=False))
stack.enter_context( stack.enter_context(
torch._functorch.config.patch( torch._functorch.config.patch(
enable_remote_autograd_cache=False)) enable_remote_autograd_cache=False))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment