"vscode:/vscode.git/clone" did not exist on "e0d8c9ef838d0a7372a4807cd978e032bd26c572"
Unverified commit bf98d2e3, authored by Byron Hsu, committed by GitHub

[PD] Support prefill overlap + Ensure no race condition (#5609)

parent e65b9f21
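Summary: this change enables the overlap scheduler on the prefill server in PD (prefill/decode) disaggregation. It adds an `event_loop_overlap_disagg_prefill` loop that queues each launched batch and processes its result one iteration later, resolves next-token ids through `resolve_batch_result` before using them, delays the KV transfer of a chunked request until `process_batch_result_disagg_prefill` (stashing the chunk boundary in `req.tmp_end_idx` and passing it to `send_kv_chunk` as `end_idx`), and removes the `ServerArgs` code that forced `disable_overlap_schedule = True` for prefill servers. Together these changes prevent a KV chunk from being sent before the overlapped forward pass has produced its results.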
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -204,6 +205,40 @@ class SchedulerDisaggregationPrefillMixin:
             # Otherwise, it hangs under high concurrency
             self.running_batch.batch_is_full = False
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_prefill(self):
+        self.result_queue = deque()
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = self.result_queue.popleft()
+                self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
@@ -212,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin:
         """
        Adapted from process_batch_result_prefill
        """
-        next_token_ids = result.next_token_ids.tolist()
+        (
+            logits_output,
+            next_token_ids,
+            extend_input_len_per_req,
+            extend_logprob_start_len_per_req,
+            bid,
+        ) = (
+            result.logits_output,
+            result.next_token_ids,
+            result.extend_input_len_per_req,
+            result.extend_logprob_start_len_per_req,
+            result.bid,
+        )
+
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        if self.enable_overlap:
+            # wait
+            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+        else:
+            next_token_ids = result.next_token_ids.tolist()
+
         for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
             req: Req
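The new event loop and the `resolve_batch_result` call above follow the deferred-result pattern used by the overlap scheduler: the result of batch N is only resolved one iteration later, while the GPU is already running batch N+1. Below is a minimal sketch of that pattern, not SGLang code; `launch`, `resolve`, and `postprocess` are hypothetical stand-ins for `run_batch`, `resolve_batch_result`, and `process_batch_result_disagg_prefill`.

from collections import deque


def overlap_loop(batches, launch, resolve, postprocess):
    # CPU post-processing of batch N-1 overlaps with GPU execution of batch N.
    result_queue = deque()
    last_batch = None
    for batch in batches:
        if batch is not None:
            # Non-blocking launch; keep only a handle for later resolution.
            result_queue.append((batch, launch(batch)))
        if last_batch is not None:
            done_batch, handle = result_queue.popleft()
            # resolve() blocks until the previous batch's outputs are ready.
            postprocess(done_batch, resolve(handle))
        last_batch = batch
    if last_batch is not None:
        # Drain the final pending result.
        done_batch, handle = result_queue.popleft()
        postprocess(done_batch, resolve(handle))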
@@ -226,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
 
-                # TODO: Not sure if this is necessary
-                if batch.next_batch_sampling_info:
-                    batch.next_batch_sampling_info.update_regex_vocab_mask()
-                    # We need to remove this for overlap schedule.
-                    self.current_stream.synchronize()
-                    batch.next_batch_sampling_info.sampling_info_done.set()
+                if self.enable_overlap:
+                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
 
     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -276,20 +326,37 @@ class SchedulerDisaggregationPrefillMixin:
                 # only finished requests to running_batch.
                 self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
-                self.send_kv_chunk(self.chunked_req)
+                if (
+                    self.enable_overlap
+                ):  # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
+                    self.chunked_req.tmp_end_idx = min(
+                        len(self.chunked_req.fill_ids),
+                        len(self.chunked_req.origin_input_ids),
+                    )
+                else:
+                    self.send_kv_chunk(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                 self.running_batch.batch_is_full = False
 
     def send_kv_chunk(
-        self: Scheduler, req: Req, token_id: Optional[int] = None
+        self: Scheduler,
+        req: Req,
+        token_id: Optional[int] = None,
+        end_idx: Optional[int] = None,
     ) -> None:
         """
        Send a prefilled chunk to the decode server
        """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
+        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
+        # the resolved length is not the same as fill_ids's length
+        end_idx = (
+            end_idx
+            if end_idx is not None
+            else min(len(req.fill_ids), len(req.origin_input_ids))
+        )
         last_chunk = token_id is not None
 
         if (not last_chunk) and (
@@ -302,7 +369,7 @@ class SchedulerDisaggregationPrefillMixin:
         req.start_send_idx = end_idx
         kv_indices = (
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
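Two smaller details in `send_kv_chunk`: the new optional `end_idx` overrides the default `min(len(req.fill_ids), len(req.origin_input_ids))` bound (used by the overlap path above), and the token-index lookup switches from chained indexing to a single 2-D index. The snippet below only illustrates that the two indexing forms select the same slice; the tensor and names are made up, assuming `req_to_token` is a 2-D torch tensor of KV slot indices.

import torch

# Stand-in for req_to_token_pool.req_to_token: one row of KV slot indices per request.
req_to_token = torch.arange(12).reshape(3, 4)
req_pool_idx, start_idx, end_idx = 1, 1, 3

# Chained indexing builds an intermediate row view and then slices it;
# a single 2-D index expresses the same selection in one operation.
assert torch.equal(
    req_to_token[req_pool_idx][start_idx:end_idx],
    req_to_token[req_pool_idx, start_idx:end_idx],
)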
@@ -539,6 +539,11 @@ class Req:
         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
 
+        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
+        # This is because kv is not ready in `process_prefill_chunk`.
+        # We use `tmp_end_idx` to store the end index of the kv cache to send.
+        self.tmp_end_idx: int = -1
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
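The -1 default acts as a sentinel: under the overlap schedule, `process_prefill_chunk` overwrites `tmp_end_idx` with the chunk boundary before `process_batch_result_disagg_prefill` passes it to `send_kv_chunk` as `end_idx`.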
@@ -2014,7 +2014,10 @@ def run_scheduler_process(
             else:
                 scheduler.event_loop_normal()
         elif disaggregation_mode == DisaggregationMode.PREFILL:
-            scheduler.event_loop_normal_disagg_prefill()
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap_disagg_prefill()
+            else:
+                scheduler.event_loop_normal_disagg_prefill()
         elif disaggregation_mode == DisaggregationMode.DECODE:
             if scheduler.enable_overlap:
                 scheduler.event_loop_overlap_disagg_decode()
@@ -388,8 +388,6 @@ class ServerArgs:
         if self.disaggregation_mode == "prefill":
             self.disable_cuda_graph = True
             logger.warning("Cuda graph is disabled for prefill server")
-            self.disable_overlap_schedule = True
-            logger.warning("Overlap scheduler is disabled for prefill server")
         elif self.disaggregation_mode == "decode":
             self.disable_radix_cache = True
             logger.warning("KV cache is forced as chunk cache for decode server")
-prompt = [0] * 431
 import json
 
 import requests
 
+prompt = """
+According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
+For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
+Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
+According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
+Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
+Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
+Give your honest take on the above text:
+"""
 response = requests.post(
     "http://0.0.0.0:8000/generate",
-    json={"input_ids": [prompt] * 32, "sampling_params": {"temperature": 0}},
+    json={"text": prompt, "sampling_params": {"temperature": 0}},
 )
 # print("Response content (raw):", response.content)
 response_json = response.json()
 print(response_json["text"])
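The test script now posts a real text prompt instead of 32 copies of a 431-token dummy `input_ids` request, and samples with temperature 0 so the generated text is deterministic and easy to compare across runs (for example, overlap vs. non-overlap prefill). It assumes a server is already answering `/generate` on `http://0.0.0.0:8000`, presumably a disaggregated prefill/decode setup behind a router.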