Unverified commit 67b4221a, authored by SangBin Cho and committed by GitHub
Browse files

[Core][5/N] Fully working chunked prefill e2e (#3884)

parent 63e7176f
......@@ -173,10 +173,18 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=group)
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group)
async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
......
......@@ -386,9 +386,8 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
type=bool,
default=False,
help='If True, the prefill requests can be chunked based on the '
action='store_true',
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument(
......
......@@ -633,7 +633,10 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
......
......@@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
self.lora_a_stacked_2d,
)
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
......
......@@ -500,7 +500,8 @@ class SequenceGroup:
def get_num_uncomputed_tokens(self) -> int:
    """Return the uncomputed-token count summed over unfinished sequences.

    Sequences for which ``is_finished()`` is true are skipped, so they
    contribute nothing to the total. The per-sequence count comes from
    ``seq.data.get_num_uncomputed_tokens()``.

    Returns:
        The total as an int; 0 when the group has no unfinished sequences.
    """
    # Only the guarded accumulation is kept: counting every sequence
    # unconditionally as well would add unfinished sequences twice and
    # would still include finished ones.
    return sum(seq.data.get_num_uncomputed_tokens()
               for seq in self.get_seqs() if not seq.is_finished())
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
......
import contextlib
import time
from typing import Dict, List, Optional, Set, Tuple
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
import numpy as np
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
......@@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]
class PreparePromptMetadata(NamedTuple):
    """Bundle of values produced by ``ModelRunner._prepare_prompt``.

    Replaces the positional tuple that method used to return, so callers
    can read results by field name. Field order still matters to callers
    that unpack the tuple positionally.
    """
    # Flat per-token lists for the prompt requests in the batch.
    input_tokens: List[int]
    input_positions: List[int]
    # Attention metadata for the prompt stage; None when there are no
    # prompt requests.
    attn_metadata: Optional[AttentionMetadataPerStage]
    prompt_lens: List[int]
    subquery_lens: List[int]
    # LoRA bookkeeping collected while walking the prompt requests.
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    slot_mapping: List[int]

    @classmethod
    def empty(cls):
        """Return an instance with every field empty or ``None``.

        This is what ``_prepare_prompt`` returns when it is given no
        prompt requests. Fresh containers are created on every call so
        callers may extend them in place without affecting other
        instances. ``cls`` is used (rather than the class name) so a
        subclass gets an instance of its own type.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            prompt_lens=[],
            subquery_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )
class PrepareDecodeMetadata(NamedTuple):
    """Bundle of values produced by ``ModelRunner._prepare_decode``.

    Replaces the positional tuple that method used to return, so callers
    can read results by field name. Field order still matters to callers
    that unpack the tuple positionally.
    """
    # Flat per-token lists for the decode requests in the batch.
    input_tokens: List[int]
    input_positions: List[int]
    # Attention metadata for the decode stage; None when there are no
    # decode requests.
    # NOTE(review): PreparePromptMetadata annotates the same field as
    # AttentionMetadataPerStage -- confirm which type is intended here.
    attn_metadata: Optional[AttentionMetadata]
    # LoRA bookkeeping collected while walking the decode requests.
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    slot_mapping: List[int]

    @classmethod
    def empty(cls):
        """Return an instance with every field empty or ``None``.

        This is what ``_prepare_decode`` returns when it is given no
        decode requests. Fresh containers are created on every call.
        ``cls`` is used (rather than the class name) so a subclass gets
        an instance of its own type.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )
class BatchType(IntEnum):
    """Which kinds of requests make up a batch."""

    PREFILL = 0  # The batch holds prefill requests only.
    DECODE = 1  # The batch holds decode requests only.
    MIXED = 2  # The batch holds both prefill and decode requests.
class ModelRunner:
def __init__(
......@@ -152,10 +214,7 @@ class ModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], List[int], List[int], Set[LoRARequest],
torch.Tensor]:
assert len(seq_group_metadata_list) > 0
) -> PreparePromptMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
......@@ -169,6 +228,9 @@ class ModelRunner:
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
......@@ -178,7 +240,8 @@ class ModelRunner:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and computed_block_nums is not None):
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
......@@ -190,13 +253,8 @@ class ModelRunner:
# it contains output tokens.
prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = len(prompt_tokens)
# Right now, the prefill_end is always same as the length of
# sequence. However, once chunked prefill is introduced, this
# assumption can be changed.
assert prefill_end == seq_data.get_len()
prompt_len = prefill_end
prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention.
......@@ -206,6 +264,14 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# Prefill has been chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
......@@ -267,20 +333,8 @@ class ModelRunner:
max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
lora_index_mapping = lora_index_mapping
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
......@@ -332,11 +386,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_prompt_len=max_prompt_len,
......@@ -345,18 +396,25 @@ class ModelRunner:
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests, multi_modal_input)
return PreparePromptMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
) -> PrepareDecodeMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
......@@ -366,6 +424,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
......@@ -424,15 +485,6 @@ class ModelRunner:
lora_index_mapping.append(0)
batch_size = graph_batch_size
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
......@@ -440,9 +492,9 @@ class ModelRunner:
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens.shape[0] == input_tokens.shape[0]
assert context_lens.shape[0] == input_positions.shape[0]
assert context_lens.shape[0] == slot_mapping.shape[0]
assert context_lens.shape[0] == len(input_tokens)
assert context_lens.shape[0] == len(input_positions)
assert context_lens.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
......@@ -464,11 +516,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None,
max_context_len=max_context_len,
max_prompt_len=None,
......@@ -477,10 +526,16 @@ class ModelRunner:
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests)
return PrepareDecodeMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
)
def _prepare_sample(
self,
......@@ -586,26 +641,66 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests, multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
subquery_lens = None
multi_modal_input = None
(
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
......@@ -615,6 +710,16 @@ class ModelRunner:
lora_mapping = None
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
......@@ -623,19 +728,49 @@ class ModelRunner:
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
}
metadata_dict.update(attn_metadata.asdict_zerocopy())
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
......@@ -646,6 +781,23 @@ class ModelRunner:
perform_sampling=False,
)
# If it is a mixed batch, the decode attn_metadata is broadcast
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
......@@ -663,8 +815,10 @@ class ModelRunner:
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Execute the model.
if attn_metadata.use_cuda_graph:
# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
......@@ -842,13 +996,10 @@ class ModelRunner:
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata.
attn_metadata = self.attn_backend.make_metadata(
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_prompt_len=None,
......@@ -857,6 +1008,14 @@ class ModelRunner:
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
attn_metadata = AttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
......@@ -950,8 +1109,8 @@ class CUDAGraphRunner:
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.context_lens,
"block_tables": attn_metadata.block_tables,
"context_lens": attn_metadata.decode_metadata.context_lens,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
......@@ -972,10 +1131,10 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
non_blocking=True)
self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
non_blocking=True)
self.input_buffers["context_lens"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph.
self.graph.replay()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment