Unverified Commit 2fa4623d authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Core] Refactor _prepare_model_input_tensors - take 2 (#6164)

parent a9a2e74d
......@@ -3,7 +3,7 @@ from typing import List, Tuple, Type
import torch
from vllm.attention import AttentionMetadata
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
......@@ -26,6 +26,10 @@ class MockAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return AttentionMetadata
@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise AttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
......@@ -7,6 +8,7 @@ __all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataBuilder",
"Attention",
"get_attn_backend",
]
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)
import torch
if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
......@@ -35,6 +39,16 @@ class AttentionBackend(ABC):
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@classmethod
def make_metadata_builder(cls, *args,
**kwargs) -> "AttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_kv_cache_shape(
......@@ -110,6 +124,33 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata)
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@abstractmethod
def __init__(self, input_builder) -> None:
raise NotImplementedError
@abstractmethod
def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise NotImplementedError
@abstractmethod
def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError
class AttentionImpl(ABC, Generic[T]):
@abstractmethod
......
......@@ -5,6 +5,7 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
......@@ -93,6 +94,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -244,6 +249,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
return self._cached_decode_metadata
class BlocksparseFlashAttentionMetadataBuilder(
CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
_metadata_cls = BlocksparseFlashAttentionMetadata
class BlocksparseFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
......
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
class FlashAttentionBackend(AttentionBackend):
......@@ -28,6 +39,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -184,6 +199,170 @@ class FlashAttentionMetadata(AttentionMetadata):
return self._cached_decode_metadata
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors."""
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
return FlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
......
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
......@@ -14,7 +14,18 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
class FlashInferBackend(AttentionBackend):
......@@ -31,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashInferMetadata
@staticmethod
def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
return FlashInferMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -188,6 +203,225 @@ class FlashInferMetadata(AttentionMetadata):
return self
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
self.paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
self.paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
is_profile_run = is_block_tables_empty(block_tables)
# Compute slot mapping.
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
# create dummy inputs.
if is_profile_run:
return
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
block_table = block_tables[seq_id]
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len)
def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
cuda_graph_pad_size: int, batch_size: int):
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
last_paged_kv_indptr = self.paged_kv_indptr[-1]
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
cuda_graph_pad_size)
self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
logits_soft_cap = getattr(runner.model_config.hf_config,
"attn_logit_softcapping", None)
if len(self.paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
device="cpu",
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype,
runner.model_config.dtype)
return FlashInferMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=runner.model_config.get_num_attention_heads(
runner.parallel_config),
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_dim=runner.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc,
device=device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph,
logits_soft_cap=logits_soft_cap)
class FlashInferImpl(AttentionImpl):
def __init__(
......
......@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
......@@ -28,6 +29,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return ROCmFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
return ROCmFlashAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -166,6 +171,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return self._cached_decode_metadata
class ROCmFlashAttentionMetadataBuilder(
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
_metadata_cls = ROCmFlashAttentionMetadata
def _make_alibi_bias(alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: Optional[List[int]],
......
"""Attention backend utils"""
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
import torch
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")
PAD_SLOT_ID = -1
if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
def is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
if isinstance(block_tables, dict) and all(
value is None for value in block_tables.values()):
return True
return False
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int,
use_v2_block_manager: bool):
"""
Compute the start index of slot mapping.
"""
start_idx = 0
if is_prompt and sliding_window is not None:
assert use_v2_block_manager or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - sliding_window)
return start_idx
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
seq_id: int, seq_len: int, context_len: int,
start_idx: int, block_size: int,
block_tables: Dict[int, List[int]]):
"""
Compute slot mapping.
"""
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([PAD_SLOT_ID] * seq_len)
return
# Mask the [0, start_idx) tokens of the prompt with
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table = block_tables[seq_id]
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
_metadata_cls: Type[TAttentionMetadata]
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int], prefix_cache_hit,
chunked_prefill_enabled):
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int):
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
......@@ -11,6 +11,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
......@@ -32,6 +33,10 @@ class XFormersBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata
@staticmethod
def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
return XFormersMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -362,6 +367,11 @@ def _get_seq_len_block_table_args(
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
_metadata_cls = XFormersMetadata
class XFormersImpl(AttentionImpl[XFormersMetadata]):
"""
If the input tensors contain prompt tokens, the layout is as follows:
......
......@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
logger = init_logger(__name__)
......@@ -136,7 +137,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if torch.cuda.get_device_capability()[0] != 9:
if current_platform.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
......@@ -145,7 +146,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
if torch.cuda.get_device_capability()[0] < 8:
if current_platform.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
......
This diff is collapsed.
......@@ -113,6 +113,21 @@ class ModelRunnerInputBase(ABC):
raise NotImplementedError
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""
@abstractmethod
def add_seq_group(self, seq_group_metadata):
"""TBA"""
raise NotImplementedError
@abstractmethod
def build(self, *args, **kwargs) -> T:
"""Build metadata with on-device tensors."""
raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
"""
Model runner interface that abstracts a particular hardware and/or type of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment