Unverified Commit 65bf2ac1 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode...

[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)

This PR combines prepare_prompt and prepare_decode into a single API. This PR also coelsce the attn metadata for prefill/decode to a single class and allow to slice them when running attn backend.

It also refactors subquery_start_loc which was not refactored in the previous PR
parent 8a7cc254
......@@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is True
assert attn_metadata.num_prefills > 0
assert attn_metadata.num_decode_tokens == 0
assert torch.allclose(
attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
assert attn_metadata.max_decode_seq_len == 0
# Test subquery start locs.
start_idx = 0
......@@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size):
start_idx += seq_len
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.subquery_start_loc,
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
# equivalent to query_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
......@@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size):
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
......@@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
)
seq_lens = []
context_lens = []
seq_group_metadata_list = []
# Assume each seq group finishes prefill.
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = list(range(context_len))
seq_data = SequenceData(seq_data)
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
......@@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.seq_lens is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
seq_lens = [context_len + 1 for context_len in context_lens]
# seq_lens are padded to expected_bs
for _ in range(expected_bs - len(seq_lens)):
seq_lens.append(1)
assert attn_metadata.seq_lens == seq_lens
start_idx = 0
start_loc = [start_idx]
for _ in context_lens:
# decode has only 1 token for query.
start_idx += 1
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
assert torch.allclose(
attn_metadata.context_lens_tensor,
torch.tensor(context_lens, dtype=torch.int, device=device))
assert attn_metadata.max_decode_seq_len == max(seq_lens)
assert torch.allclose(
attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(seq_lens, dtype=torch.int, device=device))
......@@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size):
# It is padded up to
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is True
assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)
# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
for _ in context_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
# query lens is all 1 for decode.
query_lens=[1 for _ in range(len(context_lens))],
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
......@@ -220,15 +257,27 @@ def test_empty_seq_group():
enforce_eager=False,
)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
......@@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(context_len))
seq_data = SequenceData(prompt_toks)
seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
......@@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
attn_metadata = model_runner._prepare_model_input(
seq_group_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(prefill_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
......@@ -8,6 +7,6 @@ __all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataPerStage",
"Attention",
"get_attn_backend",
]
......@@ -21,7 +21,7 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError
@staticmethod
......@@ -53,8 +53,34 @@ class AttentionBackend(ABC):
@dataclass
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
class AttentionMetadata:
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run prefill
attention."""
pass
@property
@abstractmethod
def decode_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run decode
attention."""
pass
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
......@@ -70,40 +96,10 @@ class AttentionMetadataPerStage:
}
T = TypeVar("T", bound=AttentionMetadataPerStage)
@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None
T = TypeVar("T", bound=AttentionMetadata)
class AttentionImpl(ABC):
class AttentionImpl(ABC, Generic[T]):
@abstractmethod
def __init__(
......@@ -125,7 +121,7 @@ class AttentionImpl(ABC):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_metadata: T,
kv_scale: float = 1.0,
) -> torch.Tensor:
raise NotImplementedError
......@@ -11,8 +11,7 @@ import torch
from vllm_flash_attn import flash_attn_varlen_func
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
......@@ -58,8 +57,7 @@ class FlashAttentionBackend(AttentionBackend):
@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
......@@ -67,9 +65,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
......@@ -84,14 +79,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
......@@ -105,6 +104,70 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None
self._cached_prefill_metadata = FlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = FlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class FlashAttentionImpl(AttentionImpl):
"""
......@@ -168,7 +231,7 @@ class FlashAttentionImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
attn_metadata: FlashAttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -228,8 +291,8 @@ class FlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
......@@ -249,7 +312,7 @@ class FlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
......@@ -264,7 +327,7 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
decode_meta.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......
......@@ -8,8 +8,7 @@ from vllm_flash_attn import flash_attn_varlen_func
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
class FlashInferBackend(AttentionBackend):
......@@ -56,9 +55,10 @@ class FlashInferBackend(AttentionBackend):
@dataclass
class FlashInferMetadata(AttentionMetadataPerStage):
is_prompt: bool
class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
use_cuda_graph: bool = False
......@@ -67,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage):
# Metadata for the prefill stage since we still
# use flash attention for prefill.
seq_start_loc: Optional[torch.Tensor] = None
max_seq_len: Optional[int] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
......@@ -113,7 +112,8 @@ class FlashInferMetadata(AttentionMetadataPerStage):
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if not self.is_prompt:
if self.num_prefills == 0:
assert self.num_decode_tokens > 0
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
self.decode_wrapper.begin_forward(
......@@ -138,6 +138,24 @@ class FlashInferMetadata(AttentionMetadataPerStage):
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
@property
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self
return None
@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
return None
return self
class FlashInferImpl(AttentionImpl):
......@@ -172,7 +190,7 @@ class FlashInferImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
attn_metadata: FlashInferMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
assert kv_scale == 1.0
......@@ -208,8 +226,8 @@ class FlashInferImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
......
......@@ -6,8 +6,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
......@@ -56,8 +55,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
......@@ -65,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
......@@ -82,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
......@@ -102,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = ROCmFlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class ROCmFlashAttentionImpl(AttentionImpl):
......@@ -198,7 +260,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
attn_metadata: ROCmFlashAttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -266,8 +328,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_seq_len,
prefill_meta.max_seq_len,
prefill_meta.max_prefill_seq_len,
prefill_meta.max_prefill_seq_len,
True,
self.scale,
)
......@@ -290,8 +352,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
......@@ -308,7 +370,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
......@@ -324,7 +386,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
decode_meta.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......
......@@ -7,8 +7,7 @@ import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
......@@ -54,8 +53,7 @@ class TorchSDPABackend(AttentionBackend):
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
AttentionMetadataPerStage):
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
......@@ -72,8 +70,26 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self
class TorchSDPABackendImpl(AttentionImpl):
return None
@property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
return None
return self
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def __init__(
self,
......@@ -200,7 +216,7 @@ class TorchSDPABackendImpl(AttentionImpl):
value_cache,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......
......@@ -9,8 +9,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
......@@ -59,7 +58,7 @@ class XFormersBackend(AttentionBackend):
@dataclass
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
......@@ -67,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
......@@ -83,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
query_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
......@@ -105,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
_cached_prefill_metadata: Optional["XFormersMetadata"] = None
_cached_decode_metadata: Optional["XFormersMetadata"] = None
def __post_init__(self):
# Set during the execution of the first attention op.
......@@ -114,8 +116,68 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None
@property
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
self._cached_prefill_metadata = XFormersMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=None,
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["XFormersMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = XFormersMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class XFormersImpl(AttentionImpl):
class XFormersImpl(AttentionImpl[XFormersMetadata]):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
......@@ -176,7 +238,7 @@ class XFormersImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[XFormersMetadata],
attn_metadata: "XFormersMetadata",
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -244,7 +306,7 @@ class XFormersImpl(AttentionImpl):
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
......@@ -261,7 +323,7 @@ class XFormersImpl(AttentionImpl):
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
decode_meta.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
......
......@@ -4,8 +4,7 @@ from typing import List, Optional
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
......@@ -57,7 +56,7 @@ class Attention(nn.Module):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
attn_metadata: AttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
......
......@@ -16,8 +16,8 @@ class PagedAttentionMetadata:
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
......@@ -166,7 +166,7 @@ class PagedAttention:
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
......@@ -182,8 +182,8 @@ class PagedAttention:
key_cache,
value_cache,
block_tables,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc[:-1],
# query_start_loc is (batch_size + 1,)
query_start_loc[:-1],
seq_lens_tensor,
context_lens,
max_query_len,
......
......@@ -618,6 +618,11 @@ class EngineArgs:
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled):
raise ValueError(
"Chunked prefill is not supported with sliding window.")
return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
......
......@@ -122,6 +122,7 @@ class RejectionSampler(nn.Module):
draft_token_ids,
bonus_token_ids,
)
return output_token_ids
def _batch_modified_rejection_sampling(
......
......@@ -654,8 +654,9 @@ class SequenceGroupMetadata:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def token_chunk_size(self) -> Optional[int]:
def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size)."""
assert self._token_chunk_size is not None
return self._token_chunk_size
......
......@@ -293,21 +293,30 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
prompt_token_ids = seq_data.get_prompt_token_ids()
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data={
new_seq_data_dict = {
target_seq_id:
SequenceData(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
),
},
}
# This is a hack. Technically, spec decoding should compute
# num_lookahead slots at one shot, but instead, it expands the batch
# and evaluate one by one right now. context_len is seq_len - 1 because
# the kv cache is filled by a previous batch in the batch expansion.
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data=new_seq_data_dict,
sampling_params=seq_group_metadata.sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
token_chunk_size=1,
)
def _split_scoring_output(
......
......@@ -114,6 +114,7 @@ class MultiStepWorker(Worker):
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)
def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
......
......@@ -159,12 +159,10 @@ class CPUModelRunner:
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=None,
max_seq_len=None,
max_decode_seq_len=None,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
prefill_metadata=None,
decode_metadata=None,
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
)
......@@ -213,7 +211,7 @@ class CPUModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
max_seq_len = max(seq_lens)
max_decode_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
......@@ -243,12 +241,10 @@ class CPUModelRunner:
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_seq_len=max_seq_len,
max_decode_seq_len=max_decode_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
num_prefills=0,
prefill_metadata=None,
decode_metadata=None,
block_tables=block_tables,
)
return (
......
......@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import BatchType, ModelRunner
from vllm.worker.model_runner import ModelRunner
logger = init_logger(__name__)
......@@ -88,85 +88,24 @@ class EmbeddingModelRunner(ModelRunner):
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors.
(
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
lora_index_mapping,
lora_prompt_mapping,
attn_metadata,
seq_lens,
_,
lora_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
num_prefill_tokens,
num_decode_tokens,
num_prefills,
) = self._prepare_model_input(seq_group_metadata_list)
# Prepare PoolingMetadata
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
prompt_lens)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
seq_lens)
metadata_dict = {
"input_tokens": input_tokens,
......@@ -178,65 +117,26 @@ class EmbeddingModelRunner(ModelRunner):
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
}
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coelescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
if attn_metadata:
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = None
pooling_metadata = PoolingMetadata(seq_groups=None,
seq_data=None,
prompt_lens=None)
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
)
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_input)
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment