Unverified commit 67b4221a, authored by SangBin Cho and committed by GitHub
Browse files

[Core][5/N] Fully working chunked prefill e2e (#3884)

parent 63e7176f
...@@ -173,10 +173,18 @@ def broadcast_tensor_dict( ...@@ -173,10 +173,18 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list], torch.distributed.broadcast_object_list([metadata_list],
src=src, src=src,
group=group) group=group)
async_handles = []
for key, value in metadata_list: for key, value in metadata_list:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = tensor_dict[key] tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group) async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()
else: else:
recv_metadata_list = [None] recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list, torch.distributed.broadcast_object_list(recv_metadata_list,
......
...@@ -386,9 +386,8 @@ class EngineArgs: ...@@ -386,9 +386,8 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.') 'prompt latency) before scheduling next prompt.')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
type=bool, action='store_true',
default=False, help='If set, the prefill requests can be chunked based on the '
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument( parser.add_argument(
......
...@@ -633,6 +633,9 @@ class LLMEngine: ...@@ -633,6 +633,9 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
self._process_sequence_group_outputs(seq_group, outputs) self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
......
...@@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1 added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding( full_lora_a_embeddings = F.embedding(
x + indices, x + indices,
self.lora_a_stacked_2d, self.lora_a_stacked_2d,
) )
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward( full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask)) x.add_(indices * added_tokens_mask))
......
...@@ -500,6 +500,7 @@ class SequenceGroup: ...@@ -500,6 +500,7 @@ class SequenceGroup:
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0 num_uncomputed_tokens = 0
for seq in self.get_seqs(): for seq in self.get_seqs():
if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens return num_uncomputed_tokens
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment