Unverified commit 67b4221a, authored by SangBin Cho, committed by GitHub
Browse files

[Core][5/N] Fully working chunked prefill e2e (#3884)

parent 63e7176f
...@@ -173,10 +173,18 @@ def broadcast_tensor_dict( ...@@ -173,10 +173,18 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list], torch.distributed.broadcast_object_list([metadata_list],
src=src, src=src,
group=group) group=group)
async_handles = []
for key, value in metadata_list: for key, value in metadata_list:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = tensor_dict[key] tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group) async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()
else: else:
recv_metadata_list = [None] recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list, torch.distributed.broadcast_object_list(recv_metadata_list,
......
...@@ -386,9 +386,8 @@ class EngineArgs: ...@@ -386,9 +386,8 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.') 'prompt latency) before scheduling next prompt.')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
type=bool, action='store_true',
default=False, help='If set, the prefill requests can be chunked based on the '
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument( parser.add_argument(
......
...@@ -633,7 +633,10 @@ class LLMEngine: ...@@ -633,7 +633,10 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs) # If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
self.scheduler.free_finished_seq_groups() self.scheduler.free_finished_seq_groups()
......
...@@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1 added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding( full_lora_a_embeddings = F.embedding(
x + indices, x + indices,
self.lora_a_stacked_2d, self.lora_a_stacked_2d,
) )
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward( full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask)) x.add_(indices * added_tokens_mask))
......
...@@ -500,7 +500,8 @@ class SequenceGroup: ...@@ -500,7 +500,8 @@ class SequenceGroup:
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0 num_uncomputed_tokens = 0
for seq in self.get_seqs(): for seq in self.get_seqs():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
......
import contextlib import contextlib
import time import time
from typing import Dict, List, Optional, Set, Tuple from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig) SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
...@@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ ...@@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
] ]
class PreparePromptMetadata(NamedTuple):
    """Result bundle returned by ``ModelRunner._prepare_prompt``.

    Replaces the positional tuple the method used to return so that callers
    can unpack (or access by name) a fixed, self-describing set of fields.
    Token-level fields are plain Python lists; the caller is responsible for
    merging them with the decode-side lists and converting to tensors.
    """

    # Flattened token ids of every prefill request in the batch.
    input_tokens: List[int]
    # Position of each entry of ``input_tokens`` within its sequence.
    input_positions: List[int]
    # Prefill-stage attention metadata; ``None`` when the batch has no
    # prefill requests (see ``empty``).
    attn_metadata: Optional[AttentionMetadataPerStage]
    # One entry per prefill request.
    prompt_lens: List[int]
    # One entry per prefill request.
    subquery_lens: List[int]
    # LoRA bookkeeping for the prefill requests.
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    # Multi-modal input tensor, if any.
    multi_modal_input: Optional[torch.Tensor]
    # KV-cache slot for each entry of ``input_tokens``.
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PreparePromptMetadata":
        """Return the metadata for a batch without any prefill request.

        Every container field is a fresh empty object (so callers may mutate
        it freely) and every optional field is ``None``.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            prompt_lens=[],
            subquery_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )
class PrepareDecodeMetadata(NamedTuple):
    """Result bundle returned by ``ModelRunner._prepare_decode``.

    Replaces the positional tuple the method used to return. Token-level
    fields are plain Python lists; the caller appends them to the
    prefill-side lists before converting everything to tensors.
    """

    # Token ids fed to the model for the decode requests.
    input_tokens: List[int]
    # Position of each entry of ``input_tokens`` within its sequence.
    input_positions: List[int]
    # Decode-stage attention metadata; ``None`` when the batch has no decode
    # requests (see ``empty``).
    attn_metadata: Optional[AttentionMetadata]
    # LoRA bookkeeping for the decode requests.
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    # KV-cache slot for each entry of ``input_tokens``.
    slot_mapping: List[int]

    @classmethod
    def empty(cls) -> "PrepareDecodeMetadata":
        """Return the metadata for a batch without any decode request.

        Every container field is a fresh empty object (so callers may mutate
        it freely) and ``attn_metadata`` is ``None``.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )
class BatchType(IntEnum):
    """Composition of a scheduled batch, as seen by the model runner.

    The driver picks the value from which per-stage attention metadata it
    built and broadcasts it, so that the other workers know how to rebuild
    the metadata (and whether a second broadcast carrying the decode
    metadata will follow).
    """

    # The batch contains only prefill requests.
    PREFILL = 0
    # The batch contains only decode requests.
    DECODE = 1
    # The batch contains both prefill and decode requests.
    MIXED = 2
class ModelRunner: class ModelRunner:
def __init__( def __init__(
...@@ -152,10 +214,7 @@ class ModelRunner: ...@@ -152,10 +214,7 @@ class ModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> PreparePromptMetadata:
List[int], List[int], List[int], Set[LoRARequest],
torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
...@@ -169,6 +228,9 @@ class ModelRunner: ...@@ -169,6 +228,9 @@ class ModelRunner:
prefix_block_tables: List[List[int]] = [] prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_input_list: List[torch.Tensor] = []
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -178,7 +240,8 @@ class ModelRunner: ...@@ -178,7 +240,8 @@ class ModelRunner:
computed_block_nums = seq_group_metadata.computed_block_nums computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled and self.scheduler_config.chunked_prefill_enabled
and computed_block_nums is not None): and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError( raise RuntimeError(
"chunked prefill cannot be used with prefix caching " "chunked prefill cannot be used with prefix caching "
"now.") "now.")
...@@ -190,13 +253,8 @@ class ModelRunner: ...@@ -190,13 +253,8 @@ class ModelRunner:
# it contains output tokens. # it contains output tokens.
prefill_end = min(seq_data.get_len(), prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size) computed_len + token_chunk_size)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = len(prompt_tokens) prompt_len = prefill_end
# Right now, the prefill_end is always same as the length of
# sequence. However, once chunked prefill is introduced, this
# assumption can be changed.
assert prefill_end == seq_data.get_len()
prompt_lens.append(prompt_len) prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
...@@ -206,6 +264,14 @@ class ModelRunner: ...@@ -206,6 +264,14 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:] prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums) prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# This prefill has been chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else: else:
prefix_block_tables.append([]) prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this # Right now, prefill start is always 0. However, this
...@@ -267,20 +333,8 @@ class ModelRunner: ...@@ -267,20 +333,8 @@ class ModelRunner:
max_subquery_len = max(subquery_lens) max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens) max_prompt_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0 assert max_subquery_len > 0
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
lora_index_mapping = lora_index_mapping
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
...@@ -332,11 +386,8 @@ class ModelRunner: ...@@ -332,11 +386,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens, prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor, prompt_lens_tensor=prompt_lens_tensor,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len, max_subquery_len=max_subquery_len,
max_context_len=None, max_context_len=None,
max_prompt_len=max_prompt_len, max_prompt_len=max_prompt_len,
...@@ -345,18 +396,25 @@ class ModelRunner: ...@@ -345,18 +396,25 @@ class ModelRunner:
context_lens=context_lens_tensor, context_lens=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=False, use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, return PreparePromptMetadata(
lora_requests, multi_modal_input) input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
)
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> PrepareDecodeMetadata:
List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
...@@ -366,6 +424,9 @@ class ModelRunner: ...@@ -366,6 +424,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set() lora_requests: Set[LoRARequest] = set()
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1 assert seq_group_metadata.token_chunk_size == 1
...@@ -424,15 +485,6 @@ class ModelRunner: ...@@ -424,15 +485,6 @@ class ModelRunner:
lora_index_mapping.append(0) lora_index_mapping.append(0)
batch_size = graph_batch_size batch_size = graph_batch_size
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens, context_lens = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
...@@ -440,9 +492,9 @@ class ModelRunner: ...@@ -440,9 +492,9 @@ class ModelRunner:
if use_captured_graph: if use_captured_graph:
# When using cuda-graph all these tensors should be # When using cuda-graph all these tensors should be
# padded. # padded.
assert context_lens.shape[0] == input_tokens.shape[0] assert context_lens.shape[0] == len(input_tokens)
assert context_lens.shape[0] == input_positions.shape[0] assert context_lens.shape[0] == len(input_positions)
assert context_lens.shape[0] == slot_mapping.shape[0] assert context_lens.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
...@@ -464,11 +516,8 @@ class ModelRunner: ...@@ -464,11 +516,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None, prompt_lens=None,
prompt_lens_tensor=None, prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None, max_subquery_len=None,
max_context_len=max_context_len, max_context_len=max_context_len,
max_prompt_len=None, max_prompt_len=None,
...@@ -477,10 +526,16 @@ class ModelRunner: ...@@ -477,10 +526,16 @@ class ModelRunner:
context_lens=context_lens, context_lens=context_lens,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, return PrepareDecodeMetadata(
lora_index_mapping, lora_prompt_mapping, lora_requests) input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
)
def _prepare_sample( def _prepare_sample(
self, self,
...@@ -586,26 +641,66 @@ class ModelRunner: ...@@ -586,26 +641,66 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping, torch.Tensor]: Set[int], LoRAMapping, torch.Tensor]:
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or prefill_reqs = []
# all decodes. decode_reqs = []
is_prompt = seq_group_metadata_list[0].is_prompt for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors. # Prepare input tensors.
if is_prompt: (
(input_tokens, input_positions, attn_metadata, prompt_lens, input_tokens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, input_positions,
lora_requests, multi_modal_input prefill_attn_metadata,
) = self._prepare_prompt(seq_group_metadata_list) prompt_lens,
else: subquery_lens,
(input_tokens, input_positions, attn_metadata, lora_index_mapping,
lora_index_mapping, lora_prompt_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list) lora_requests,
prompt_lens = [] multi_modal_input,
subquery_lens = None slot_mapping,
multi_modal_input = None ) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens) subquery_lens)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config: if self.lora_config:
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
lora_index_mapping, lora_index_mapping,
...@@ -615,6 +710,16 @@ class ModelRunner: ...@@ -615,6 +710,16 @@ class ModelRunner:
lora_mapping = None lora_mapping = None
# Broadcast the metadata. # Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
metadata_dict = { metadata_dict = {
"input_tokens": input_tokens, "input_tokens": input_tokens,
"input_positions": input_positions, "input_positions": input_positions,
...@@ -623,19 +728,49 @@ class ModelRunner: ...@@ -623,19 +728,49 @@ class ModelRunner:
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input, "multi_modal_input": multi_modal_input,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
} }
metadata_dict.update(attn_metadata.asdict_zerocopy()) if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0) broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else: else:
metadata_dict = broadcast_tensor_dict(src=0) metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens") input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions") input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop( selected_token_indices = metadata_dict.pop(
"selected_token_indices") "selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping") lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests") lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input") multi_modal_input = metadata_dict.pop("multi_modal_input")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict) num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None, seq_data=None,
...@@ -646,6 +781,23 @@ class ModelRunner: ...@@ -646,6 +781,23 @@ class ModelRunner:
perform_sampling=False, perform_sampling=False,
) )
# If it is a mixed batch, the decode attn_metadata is broadcast
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping, sampling_metadata, lora_requests, lora_mapping,
multi_modal_input) multi_modal_input)
...@@ -663,8 +815,10 @@ class ModelRunner: ...@@ -663,8 +815,10 @@ class ModelRunner:
if self.lora_config: if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping) self.set_active_loras(lora_requests, lora_mapping)
# Execute the model. # Currently cuda graph is only supported by the decode phase.
if attn_metadata.use_cuda_graph: prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0] graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[graph_batch_size]
else: else:
...@@ -842,13 +996,10 @@ class ModelRunner: ...@@ -842,13 +996,10 @@ class ModelRunner:
# memory usage of CUDA graph. # memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list): for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata. # Create dummy attn_metadata.
attn_metadata = self.attn_backend.make_metadata( decode_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None, prompt_lens=None,
prompt_lens_tensor=None, prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None, max_subquery_len=None,
max_context_len=self.max_context_len_to_capture, max_context_len=self.max_context_len_to_capture,
max_prompt_len=None, max_prompt_len=None,
...@@ -857,6 +1008,14 @@ class ModelRunner: ...@@ -857,6 +1008,14 @@ class ModelRunner:
context_lens=context_lens[:batch_size], context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size], block_tables=block_tables[:batch_size],
use_cuda_graph=True, use_cuda_graph=True,
)
attn_metadata = AttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
...@@ -950,8 +1109,8 @@ class CUDAGraphRunner: ...@@ -950,8 +1109,8 @@ class CUDAGraphRunner:
"positions": positions, "positions": positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.context_lens, "context_lens": attn_metadata.decode_metadata.context_lens,
"block_tables": attn_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} self.output_buffers = {"hidden_states": hidden_states}
return return
...@@ -972,10 +1131,10 @@ class CUDAGraphRunner: ...@@ -972,10 +1131,10 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True) non_blocking=True)
self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, self.input_buffers["context_lens"].copy_(
non_blocking=True) attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, self.input_buffers["block_tables"].copy_(
non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph. # Run the graph.
self.graph.replay() self.graph.replay()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment