Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
from functools import lru_cache
from pathlib import Path
from typing import Optional
import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
def get_cache_dir():
def get_cache_dir() -> Path:
"""Get the path to the cache for storing downloaded assets."""
path = Path(envs.VLLM_ASSETS_CACHE)
path.mkdir(parents=True, exist_ok=True)
return path
@lru_cache
def get_vllm_public_assets(filename: str,
s3_prefix: Optional[str] = None) -> Path:
"""
Download an asset file from ``s3://vllm-public-assets``
and return the path to the downloaded file.
"""
asset_directory = get_cache_dir() / "vllm_public_assets"
asset_directory.mkdir(parents=True, exist_ok=True)
asset_path = asset_directory / filename
if not asset_path.exists():
if s3_prefix is not None:
filename = s3_prefix + "/" + filename
global_http_connection.download_file(
f"{vLLM_S3_BUCKET_URL}/{filename}",
asset_path,
timeout=VLLM_IMAGE_FETCH_TIMEOUT)
return asset_path
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal
import torch
from PIL import Image
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.assets.base import get_vllm_public_assets
from .base import get_cache_dir
@lru_cache
def get_air_example_data_2_asset(filename: str) -> Image.Image:
"""
Download and open an image from
``s3://air-example-data-2/vllm_opensource_llava/``.
"""
image_directory = get_cache_dir() / "air-example-data-2"
image_directory.mkdir(parents=True, exist_ok=True)
image_path = image_directory / filename
if not image_path.exists():
base_url = "https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"
global_http_connection.download_file(f"{base_url}/{filename}",
image_path,
timeout=VLLM_IMAGE_FETCH_TIMEOUT)
return Image.open(image_path)
VLM_IMAGES_DIR = "vision_model_images"
@dataclass(frozen=True)
......@@ -36,4 +15,16 @@ class ImageAsset:
@property
def pil_image(self) -> Image.Image:
return get_air_example_data_2_asset(f"{self.name}.jpg")
image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
s3_prefix=VLM_IMAGES_DIR)
return Image.open(image_path)
@property
def image_embeds(self) -> torch.Tensor:
"""
Image embeddings, only used for testing purposes with llava 1.5.
"""
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
s3_prefix=VLM_IMAGES_DIR)
return torch.load(image_path)
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder)
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
......@@ -8,7 +9,9 @@ __all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
]
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
......@@ -7,7 +8,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import torch
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerInputBuilderBase)
class AttentionType(Enum):
......@@ -34,6 +37,11 @@ class AttentionBackend(ABC):
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_state_cls() -> Type["AttentionState"]:
raise NotImplementedError
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
......@@ -75,6 +83,9 @@ class AttentionBackend(ABC):
) -> None:
raise NotImplementedError
def advance_step(self, num_seqs: int, num_queries: int):
raise NotImplementedError
@dataclass
class AttentionMetadata:
......@@ -123,6 +134,47 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata)
class AttentionState(ABC, Generic[T]):
"""Holds attention backend-specific objects reused during the
lifetime of the model runner."""
@abstractmethod
def __init__(self, runner: "ModelRunnerBase"):
...
@abstractmethod
@contextmanager
def graph_capture(self, max_batch_size: int):
"""Context manager used when capturing CUDA graphs."""
yield
@abstractmethod
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
"""Clone attention state to save in CUDA graph metadata."""
...
@abstractmethod
def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
"""Get attention metadata for CUDA graph capture of batch_size."""
...
@abstractmethod
def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...
@abstractmethod
def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
attn_metadata: T) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...
@abstractmethod
def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
"""Prepare state for forward pass."""
...
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
......
......@@ -5,7 +5,8 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
......@@ -98,6 +99,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......
......@@ -3,21 +3,123 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.utils import make_tensor_with_pad
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# custom op does not support tuple input
real_window_size: Tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return _flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
window_size=real_window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
block_table=block_table,
)
@flash_attn_varlen_func.register_fake # type: ignore
def _(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(q)
@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
def flash_attn_with_kvcache(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return _flash_attn_with_kvcache(
decode_query,
key_cache,
value_cache,
cache_seqlens=cache_seqlens,
block_table=block_table,
softmax_scale=softmax_scale,
causal=causal,
alibi_slopes=alibi_slopes,
softcap=softcap,
)
@flash_attn_with_kvcache.register_fake # type: ignore
def _(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return torch.empty_like(decode_query)
class FlashAttentionBackend(AttentionBackend):
......@@ -41,6 +143,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -196,6 +302,51 @@ class FlashAttentionMetadata(AttentionMetadata):
)
return self._cached_decode_metadata
def advance_step(self, num_seqs: int, num_queries: int):
"""
Update metadata in-place to advance one decode step.
"""
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
......@@ -259,7 +410,11 @@ class FlashAttentionMetadataBuilder(
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
......@@ -310,7 +465,8 @@ class FlashAttentionMetadataBuilder(
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
block_tables = torch.from_numpy(input_block_tables).to(
device=device, non_blocking=True)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
......@@ -320,15 +476,15 @@ class FlashAttentionMetadataBuilder(
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
......@@ -344,10 +500,6 @@ class FlashAttentionMetadataBuilder(
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
return FlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
......@@ -516,7 +668,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
out = flash_attn_varlen_func(
out = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
......@@ -536,7 +688,8 @@ class FlashAttentionImpl(AttentionImpl):
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
output[:num_prefill_tokens] = flash_attn_varlen_func(
output[:
num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
......@@ -553,7 +706,8 @@ class FlashAttentionImpl(AttentionImpl):
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = flash_attn_with_kvcache(
output[
num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
......@@ -562,6 +716,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
# Reshape the output tensor.
......
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
import vllm.attention.backends.flash_attn # noqa
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
......@@ -16,12 +21,13 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionState, AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
......@@ -45,6 +51,10 @@ class FlashInferBackend(AttentionBackend):
def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
return FlashInferMetadataBuilder
@staticmethod
def get_state_cls() -> Type["FlashInferState"]:
return FlashInferState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -74,6 +84,162 @@ class FlashInferBackend(AttentionBackend):
return [64, 128, 256]
class FlashInferState(AttentionState):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
self._workspace_buffer = None
self._decode_wrapper = None
self._prefill_wrapper = None
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
return self._workspace_buffer
def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
return self._prefill_wrapper
def _get_decode_wrapper(self):
if self._decode_wrapper is None:
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_decode_wrapper = None
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._graph_decode_workspace_buffer = self._get_workspace_buffer()
self._graph_indices_buffer = torch.empty(
max_batch_size * self.runner.cache_config.num_gpu_blocks,
dtype=torch.int32,
device=self.runner.device)
self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
dtype=torch.int32,
device=self.runner.device)
self._graph_last_page_len_buffer = torch.empty(
max_batch_size, dtype=torch.int32, device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._graph_decode_workspace_buffer
del self._graph_indices_buffer
del self._graph_indptr_buffer
del self._graph_last_page_len_buffer
del self._graph_decode_wrapper
def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
state = self.__class__(self.runner)
state._workspace_buffer = self._graph_decode_workspace_buffer
state._decode_wrapper = self._graph_decode_wrapper
state._prefill_wrapper = self._get_prefill_wrapper()
return state
def graph_capture_get_metadata_for_batch(self, batch_size: int):
assert self._is_graph_capturing
_indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
_last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(0,
batch_size,
dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
self.runner.block_size,
dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
block_tables=self._graph_block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.runner.model_config.get_head_size(),
page_size=self.runner.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.runner.device,
data_type=kv_cache_dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
attn_metadata.begin_forward()
return attn_metadata
def get_graph_input_buffers(self, attn_metadata):
return {
"slot_mapping": attn_metadata.slot_mapping,
}
def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
return
def begin_forward(self, model_input):
assert not self._is_graph_capturing
state = self
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
)
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
model_input.attn_metadata.begin_forward()
@dataclass
class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
......@@ -116,6 +282,7 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
is_profile_run: bool = False
def __post_init__(self):
# Refer to
......@@ -139,11 +306,11 @@ class FlashInferMetadata(AttentionMetadata):
assert self.paged_kv_last_page_len is not None
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# The prefill stage does not read kv cache.
# Both paged_kv_indices and paged_kv_last_page_len are empty.
# paged_kv_indptr is a zero tensor with size batch_size + 1.
self.paged_kv_indptr = torch.zeros(batch_size + 1,
device=self.device)
# We will use flash attention for profiling to
# determine the number of blocks. Therefore,
# we don't need to prepare the input for flashinfer for profile run.
if not self.is_profile_run:
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
......@@ -244,6 +411,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []
self.is_profile_run: bool = False
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
......@@ -300,6 +469,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# and paged_kv_last_page_len for profile run because we will
# create dummy inputs.
if is_profile_run:
self.is_profile_run = is_profile_run
return
block_table = block_tables[seq_id]
......@@ -356,7 +526,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
last_paged_kv_indptr = self.paged_kv_indptr[-1]
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
......@@ -371,12 +542,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
assert device is not None
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
......@@ -392,10 +564,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
if len(self.paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
......@@ -432,7 +600,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc=query_start_loc,
device=device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph)
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)
class FlashInferImpl(AttentionImpl):
......@@ -516,7 +685,7 @@ class FlashInferImpl(AttentionImpl):
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func(
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key,
v=value,
......
......@@ -8,6 +8,7 @@ import torch
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
......@@ -28,6 +29,10 @@ class IpexAttnBackend(AttentionBackend):
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
return IpexAttnMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......
from dataclasses import dataclass
from typing import List, Tuple
from typing import List, Tuple, Type
import openvino as ov
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import CommonAttentionState
class OpenVINOAttentionBackend(AttentionBackend):
......@@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
return OpenVINOAttentionMetadata(*args, **kwargs)
......
......@@ -6,6 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
class PallasAttentionBackend(AttentionBackend):
......@@ -18,6 +19,10 @@ class PallasAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......
......@@ -7,7 +7,8 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
......@@ -33,6 +34,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
return ROCmFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -508,6 +513,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
......@@ -517,6 +523,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
k_scale,
v_scale,
)
if decode_meta := attn_metadata.decode_metadata:
......
......@@ -8,6 +8,7 @@ from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu
......@@ -34,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......
"""Attention backend utils"""
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
import numpy as np
import torch
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.utils import make_tensor_with_pad
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
# Error string(s) for encoder/decoder
# unsupported attention scenarios
......@@ -13,6 +19,10 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
PAD_SLOT_ID = -1
# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
......@@ -46,6 +56,29 @@ def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
return start_idx
def _compute_slot_mapping_python(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
for i in range(range_start, range_end):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
def _compute_slot_mapping_numpy(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
block_table_array = np.array(block_table)
idx = np.arange(range_start, range_end)
block_offset = idx % block_size
idx //= block_size
seq_slot_mapping_array = block_table_array[idx]
seq_slot_mapping_array *= block_size
seq_slot_mapping_array += block_offset
slot_mapping.extend(seq_slot_mapping_array)
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
seq_id: int, seq_len: int, context_len: int,
start_idx: int, block_size: int,
......@@ -67,13 +100,22 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
padding_mask_len = max(0, start_idx - context_len)
slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
range_start = max(start_idx, context_len)
range_end = seq_len
numel = range_end - range_start
block_table = block_tables[seq_id]
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
# numpy implementation will be faster than python if we have
# many elements, otherwise it will be slower.
if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
_compute_slot_mapping_python(slot_mapping, block_table, range_start,
range_end, block_size)
else:
_compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
range_end, block_size)
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
......@@ -181,7 +223,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
......@@ -191,15 +234,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
......@@ -215,10 +258,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
......@@ -235,3 +274,69 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
class CommonAttentionState(AttentionState):
def __init__(self, runner: "ModelRunnerBase"):
self.runner = runner
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
def graph_clone(self, batch_size: int) -> "CommonAttentionState":
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(self, batch_size: int):
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
)
return attn_metadata
def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
return {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
def prepare_graph_input_buffers(self, input_buffers,
attn_metadata) -> None:
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
def begin_forward(self, model_input) -> None:
return
......@@ -11,7 +11,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
......@@ -37,6 +38,10 @@ class XFormersBackend(AttentionBackend):
def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
return XFormersMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -604,6 +609,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
......@@ -613,6 +619,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
k_scale,
v_scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
......
......@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
......
......@@ -90,6 +90,7 @@ class PagedAttention:
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
......
......@@ -256,6 +256,7 @@ class PagedAttention:
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
......@@ -265,6 +266,8 @@ class PagedAttention:
max_query_len: int,
alibi_slopes: Optional[torch.Tensor],
sliding_window: Optional[int],
k_scale: float,
v_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
......@@ -272,6 +275,7 @@ class PagedAttention:
key,
value,
output,
kv_cache_dtype,
key_cache,
value_cache,
block_tables,
......@@ -280,6 +284,8 @@ class PagedAttention:
seq_lens_tensor,
context_lens,
max_query_len,
k_scale,
v_scale,
alibi_slopes,
sliding_window,
)
......
......@@ -18,6 +18,8 @@ if triton.__version__ >= "2.1.0":
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
......@@ -117,11 +119,16 @@ if triton.__version__ >= "2.1.0":
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N]
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
......@@ -161,12 +168,16 @@ if triton.__version__ >= "2.1.0":
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0) # [N,D]
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc += tl.dot(p, v)
# # update m_i and l_i
l_i = l_i_new
......@@ -225,8 +236,8 @@ if triton.__version__ >= "2.1.0":
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
......@@ -336,7 +347,6 @@ if triton.__version__ >= "2.1.0":
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
......@@ -442,6 +452,8 @@ if triton.__version__ >= "2.1.0":
V_cache,
B_Loc,
sm_scale,
k_scale,
v_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
......@@ -537,11 +549,16 @@ if triton.__version__ >= "2.1.0":
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
k_load = tl.load(K_cache + off_k,
mask=dim_mask[:, None] &
((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0) # [D,N]
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * k_scale).to(q.dtype)
else:
k = k_load
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
......@@ -573,12 +590,16 @@ if triton.__version__ >= "2.1.0":
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
v_load = tl.load(V_cache + off_v,
mask=dim_mask[None, :] &
((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
if v_load.dtype.is_fp8():
v = (v_load.to(tl.float32) * v_scale).to(q.dtype)
else:
v = v_load
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
......@@ -650,8 +671,8 @@ if triton.__version__ >= "2.1.0":
((start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
......@@ -675,6 +696,7 @@ if triton.__version__ >= "2.1.0":
k,
v,
o,
kv_cache_dtype: str,
k_cache,
v_cache,
b_loc,
......@@ -682,17 +704,41 @@ if triton.__version__ >= "2.1.0":
b_seq_len,
b_ctx_len,
max_input_len,
k_scale: float = 1.0,
v_scale: float = 1.0,
alibi_slopes=None,
sliding_window=None):
cap = current_platform.get_device_capability()
BLOCK = 32 if cap[0] >= 8 else 32
NUM_WARPS = 8
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
if q.dtype is torch.float32:
BLOCK = BLOCK // 2
# Conversion of FP8 Tensor from uint8 storage to
# appropriate torch.dtype for interpretation by Triton
if "fp8" in kv_cache_dtype:
assert (k_cache.dtype == torch.uint8)
assert (v_cache.dtype == torch.uint8)
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
target_dtype = torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
target_dtype = torch.float8_e5m2
else:
raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype)
k_cache = k_cache.view(target_dtype)
v_cache = v_cache.view(target_dtype)
if (k_cache.dtype == torch.uint8
or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"):
raise ValueError("kv_cache_dtype='auto' unsupported for\
FP8 KV Cache prefill kernel")
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
......@@ -709,7 +755,6 @@ if triton.__version__ >= "2.1.0":
if sliding_window is None or sliding_window <= 0:
sliding_window = 0
num_warps = 8 if Lk <= 64 else 4
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
......@@ -719,6 +764,8 @@ if triton.__version__ >= "2.1.0":
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
......@@ -757,7 +804,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL=Lk,
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_warps=NUM_WARPS,
num_stages=1,
)
return
......@@ -770,6 +817,8 @@ if triton.__version__ >= "2.1.0":
v_cache,
b_loc,
sm_scale,
k_scale,
v_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
......@@ -807,7 +856,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
SLIDING_WINDOW=sliding_window,
num_warps=num_warps,
num_warps=NUM_WARPS,
num_stages=1,
)
return
import enum
import os
from contextlib import contextmanager
from functools import lru_cache
from typing import Optional, Type
from typing import Generator, Optional, Type
import torch
......@@ -8,7 +10,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu
logger = init_logger(__name__)
......@@ -24,6 +26,66 @@ class _Backend(enum.Enum):
IPEX = enum.auto()
def backend_name_to_enum(backend_name: str) -> _Backend:
assert backend_name is not None
backend_members = _Backend.__members__
if backend_name not in backend_members:
raise ValueError(f"Invalid attention backend '{backend_name}'. "
f"Available backends: {', '.join(backend_members)} "
"(case-sensitive).")
return _Backend[backend_name]
def get_env_variable_attn_backend() -> Optional[_Backend]:
'''
Get the backend override specified by the vLLM attention
backend environment variable, if one is specified.
Returns:
* _Backend enum value if an override is specified
* None otherwise
'''
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return (None
if backend_name is None else backend_name_to_enum(backend_name))
# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
'''
Force all attention operations to use a specified backend.
Passing `None` for the argument re-enables automatic
backend selection.,
Arguments:
* attn_backend: backend selection (None to revert to auto)
'''
global forced_attn_backend
forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> Optional[_Backend]:
'''
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
'''
return forced_attn_backend
@lru_cache(maxsize=None)
def get_attn_backend(
num_heads: int,
......@@ -101,16 +163,20 @@ def which_attn_to_use(
# Default case.
selected_backend = _Backend.FLASH_ATTN
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
backend_members = _Backend.__members__
if backend_by_env_var not in backend_members:
raise ValueError(
f"Invalid attention backend '{backend_by_env_var}'. "
f"Available backends: {', '.join(backend_members)} "
"(case-sensitive).")
selected_backend = _Backend[backend_by_env_var]
selected_backend = backend_name_to_enum(backend_by_env_var)
if is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
......@@ -127,7 +193,7 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX
if is_tpu():
if current_platform.is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
......@@ -193,3 +259,35 @@ def which_attn_to_use(
selected_backend = _Backend.XFORMERS
return selected_backend
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: _Backend) -> Generator[None, None, None]:
'''
Globally force a vLLM attention backend override within a
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.
Arguments:
* attn_backend: attention backend to force
Returns:
* Generator
'''
# Save the current state of the global backend override (if any)
original_value = get_global_forced_attn_backend()
# Globally force the new backend override
global_force_attn_backend(attn_backend)
# Yield control back to the enclosed code block
try:
yield
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
"""Token blocks."""
from typing import List
from typing import List, Optional
from vllm.utils import Device
......@@ -37,5 +37,47 @@ class PhysicalTokenBlock:
f'computed={self.computed})')
# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
class BlockTable:
"""Holds a list of blocks with caching of their associated block_ids
"""
def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None):
self._blocks: List[PhysicalTokenBlock] = []
self._block_ids: List[int] = []
if blocks is not None:
for block in blocks:
self.append(block)
def append(self, block: PhysicalTokenBlock):
self._blocks.append(block)
self._block_ids.append(block.block_number)
def __len__(self) -> int:
return len(self._blocks)
def __getitem__(self, key):
return self._blocks[key]
def __setitem__(self, key, value):
if isinstance(key, slice):
blocks = value
self._blocks[key] = blocks
self._block_ids[key] = [b.block_number for b in blocks]
else:
block = value
self._blocks[key] = block
self._block_ids[key] = block.block_number
def reset(self):
self._blocks = []
self._block_ids = []
def copy(self) -> "BlockTable":
return BlockTable(self._blocks)
def list(self) -> List[PhysicalTokenBlock]:
return self._blocks
def ids(self) -> List[int]:
return self._block_ids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment