Unverified Commit 7ce36068 authored by Lianmin Zheng, committed by GitHub

Faster overlap mode scheduler (#1738)

parent efb099cd
@@ -55,7 +55,7 @@ class TpModelWorkerClient:
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
 
-        # Launch a thread
+        # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
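For context on the surrounding hunks: the worker owns an input queue, an output queue, and a dedicated CUDA stream, so the forward loop can run on its own stream while the scheduler thread keeps preparing batches. A minimal sketch of that wiring, with an assumed OverlapWorker name and a toy stand-in for the forward pass (not the commit's code):

import threading
from queue import Queue

import torch

class OverlapWorker:
    # Toy skeleton: a background thread drains input_queue on its own
    # CUDA stream and publishes results through output_queue.
    def __init__(self):
        self.input_queue = Queue()
        self.output_queue = Queue()
        self.forward_stream = torch.cuda.Stream()
        self.forward_thread = threading.Thread(target=self._loop, daemon=True)
        self.forward_thread.start()

    def _loop(self):
        with torch.cuda.stream(self.forward_stream):
            while True:
                batch = self.input_queue.get()
                self.output_queue.put(batch * 2)  # stand-in for a real forward pass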
@@ -64,6 +64,12 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()
 
+        self.copy_queue = Queue()
+        self.copy_thread = threading.Thread(
+            target=self.copy_thread_func,
+        )
+        self.copy_thread.start()
+
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -86,7 +92,10 @@ class TpModelWorkerClient:
     @torch.inference_mode()
     def forward_thread_func_(self):
         while True:
+            self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            self.has_inflight_batch = True
+            self.launch_event = threading.Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
@@ -100,6 +109,7 @@ class TpModelWorkerClient:
             logits_output, next_token_ids = self.worker.forward_batch_generation(
                 model_worker_batch
             )
+            self.launch_event.set()
 
             # Update the future token ids map
             bs = len(model_worker_batch.seq_lens)
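Taken together, has_inflight_batch and launch_event form a launch handshake: the forward thread marks a batch in flight when it dequeues it, enqueues the GPU work, then sets the event; the consumer (resulve_batch_result, next hunk) waits on the event so it never consumes results for a batch whose kernels have not been launched yet. A hedged sketch of that handshake, where enqueue_forward is a hypothetical stand-in for worker.forward_batch_generation:

import threading
from queue import Queue

def enqueue_forward(batch):
    pass  # hypothetical stand-in for worker.forward_batch_generation(batch)

class LaunchHandshake:
    def __init__(self):
        self.input_queue = Queue()
        self.has_inflight_batch = False
        self.launch_event = threading.Event()

    def forward_loop(self):
        while True:
            self.has_inflight_batch = False        # idle while blocked on get()
            batch = self.input_queue.get()
            self.has_inflight_batch = True
            self.launch_event = threading.Event()  # fresh event per batch
            enqueue_forward(batch)                 # launch the GPU work
            self.launch_event.set()                # signal: kernels are enqueued

    def wait_for_launch(self):
        # Consumer side: wait only if a batch is actually in flight.
        if self.has_inflight_batch:
            self.launch_event.wait()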
@@ -113,13 +123,23 @@ class TpModelWorkerClient:
                 torch.int32
             )
 
-            # Set the result
-            next_token_ids = next_token_ids.tolist()
-            assert logits_output.next_token_logprobs is None, "Not supported"
-            self.output_queue.put((None, next_token_ids))
+            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+            copy_event = torch.cuda.Event(blocking=True)
+            copy_event.record()
+            self.copy_queue.put((copy_event, next_token_ids))
+
+    def copy_thread_func(self):
+        while True:
+            copy_event, next_token_ids = self.copy_queue.get()
+            while not copy_event.query():
+                time.sleep(1e-5)
+            self.output_queue.put((None, next_token_ids.tolist()))
 
     def resulve_batch_result(self, bid: int):
         logits_output, next_token_ids = self.output_queue.get()
+        if self.has_inflight_batch:
+            # Wait until the batch is launched
+            self.launch_event.wait()
         return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
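The core speedup is in this last hunk: the forward thread no longer calls next_token_ids.tolist() directly, which would block it on a full device-to-host synchronization. Instead it issues a non-blocking copy, records a CUDA event behind it, and hands both to the copy thread, which polls the event and only then converts the already-copied CPU tensor to a list. A self-contained sketch of that pattern, with illustrative names (forward_side, copy_side) that are not from the commit:

import threading
import time
from queue import Queue

import torch

copy_queue = Queue()
output_queue = Queue()

def forward_side(tokens_gpu):
    # Kick off a device-to-host copy without forcing a full sync
    # (a pinned destination is needed for the copy to be truly async).
    tokens_cpu = tokens_gpu.to("cpu", non_blocking=True)
    # The event completes only after everything queued before it on the
    # current stream, including the copy above, has finished.
    copy_event = torch.cuda.Event(blocking=True)
    copy_event.record()
    copy_queue.put((copy_event, tokens_cpu))

def copy_side():
    while True:
        copy_event, tokens_cpu = copy_queue.get()
        # Busy-poll with a tiny sleep instead of copy_event.synchronize(),
        # mirroring the commit's copy_thread_func.
        while not copy_event.query():
            time.sleep(1e-5)
        output_queue.put(tokens_cpu.tolist())

if torch.cuda.is_available():
    threading.Thread(target=copy_side, daemon=True).start()
    forward_side(torch.randint(0, 32000, (8,), device="cuda"))
    print(output_queue.get())

Polling query() with a microsleep keeps the copy thread responsive without parking it inside a driver-level synchronize, so the forward thread's stream is never stalled waiting on the result consumer.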