Unverified Commit cf470fea authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Make token mapping non-blocking in the overlapped mode (#1740)

parent 45d5af24
...@@ -86,16 +86,15 @@ class TpModelWorkerClient: ...@@ -86,16 +86,15 @@ class TpModelWorkerClient:
@torch.inference_mode() @torch.inference_mode()
def forward_thread_func_(self): def forward_thread_func_(self):
while True: while True:
tic1 = time.time()
model_worker_batch, future_token_ids_ct = self.input_queue.get() model_worker_batch, future_token_ids_ct = self.input_queue.get()
# Resolve future tokens in the input # Resolve future tokens in the input
tic2 = time.time() input_ids = model_worker_batch.input_ids
resolved_input_ids = model_worker_batch.input_ids input_ids[:] = torch.where(
future_mask = resolved_input_ids < 0 input_ids < 0,
resolved_input_ids[future_mask] = self.future_token_ids_map[ self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
-resolved_input_ids[future_mask] input_ids,
] )
# Run forward # Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation( logits_output, next_token_ids = self.worker.forward_batch_generation(
...@@ -119,15 +118,6 @@ class TpModelWorkerClient: ...@@ -119,15 +118,6 @@ class TpModelWorkerClient:
assert logits_output.next_token_logprobs is None, "Not supported" assert logits_output.next_token_logprobs is None, "Not supported"
self.output_queue.put((None, next_token_ids)) self.output_queue.put((None, next_token_ids))
if False:
tic3 = time.time()
self.acc_time_with_waiting += tic3 - tic1
self.acc_time_without_waiting += tic3 - tic2
if self.forward_queue.qsize() == 0:
logger.info(
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)
def resulve_batch_result(self, bid: int): def resulve_batch_result(self, bid: int):
logits_output, next_token_ids = self.output_queue.get() logits_output, next_token_ids = self.output_queue.get()
return logits_output, next_token_ids return logits_output, next_token_ids
......
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment