Unverified Commit cf470fea authored by Lianmin Zheng, committed by GitHub

Make token mapping non-blocking in the overlapped mode (#1740)

parent 45d5af24
@@ -86,16 +86,15 @@ class TpModelWorkerClient:
     @torch.inference_mode()
     def forward_thread_func_(self):
         while True:
-            tic1 = time.time()
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
 
             # Resolve future tokens in the input
-            tic2 = time.time()
-            resolved_input_ids = model_worker_batch.input_ids
-            future_mask = resolved_input_ids < 0
-            resolved_input_ids[future_mask] = self.future_token_ids_map[
-                -resolved_input_ids[future_mask]
-            ]
+            input_ids = model_worker_batch.input_ids
+            input_ids[:] = torch.where(
+                input_ids < 0,
+                self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
+                input_ids,
+            )
 
             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
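The gist of the change: negative entries in input_ids are placeholders ("future" tokens) whose real values are written into self.future_token_ids_map by an earlier batch, and -id is the index into that map. The old boolean-mask scatter generally forces a device-to-host synchronization under CUDA (the host has to learn how many elements the mask selects), while the new torch.where plus clamped gather is pure element-wise work that can stay enqueued asynchronously. Below is a minimal, self-contained sketch of the two variants on made-up CPU tensors; it illustrates the idea and is not the sglang code itself.

# Sketch only: tensor values are invented for illustration.
import torch

future_token_ids_map = torch.tensor([0, 101, 102, 103])  # slot 0 unused
input_ids = torch.tensor([7, -1, 9, -3])  # -1 and -3 are future tokens

# Old, blocking variant: boolean-mask indexing (syncs on CUDA).
blocking = input_ids.clone()
mask = blocking < 0
blocking[mask] = future_token_ids_map[-blocking[mask]]

# New, non-blocking variant: torch.where + clamp, as in the commit.
non_blocking = input_ids.clone()
non_blocking[:] = torch.where(
    non_blocking < 0,
    future_token_ids_map[torch.clamp(-non_blocking, min=0)],
    non_blocking,
)

assert torch.equal(blocking, non_blocking)  # both give [7, 101, 9, 103]
print(non_blocking.tolist())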
@@ -119,15 +118,6 @@ class TpModelWorkerClient:
                 assert logits_output.next_token_logprobs is None, "Not supported"
             self.output_queue.put((None, next_token_ids))
-
-            if False:
-                tic3 = time.time()
-                self.acc_time_with_waiting += tic3 - tic1
-                self.acc_time_without_waiting += tic3 - tic2
-                if self.forward_queue.qsize() == 0:
-                    logger.info(
-                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
-                    )
 
     def resulve_batch_result(self, bid: int):
         logits_output, next_token_ids = self.output_queue.get()
         return logits_output, next_token_ids
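For context on the surrounding machinery: the worker thread above pops (model_worker_batch, future_token_ids_ct) from input_queue, and resulve_batch_result is what the scheduler calls later to collect the finished batch from output_queue, so CPU-side scheduling of the next batch can overlap with the GPU forward pass. The sketch below is a stripped-down stand-in for that producer/consumer pattern built on plain queue.Queue objects; the names and the fake "forward pass" are illustrative assumptions, not sglang's implementation.

# Sketch only: a toy overlap loop, not the real TpModelWorkerClient.
import queue
import threading

input_queue: queue.Queue = queue.Queue()
output_queue: queue.Queue = queue.Queue()

def forward_thread():
    while True:
        batch = input_queue.get()
        if batch is None:  # shutdown signal
            break
        # Stand-in for the real forward pass; just echoes token ids + 1.
        next_token_ids = [t + 1 for t in batch]
        output_queue.put((None, next_token_ids))

worker = threading.Thread(target=forward_thread, daemon=True)
worker.start()

# Scheduler side: submit a batch, keep doing CPU work, resolve the result later.
input_queue.put([7, 9, 11])
logits_output, next_token_ids = output_queue.get()  # what resulve_batch_result does
print(next_token_ids)  # [8, 10, 12]
input_queue.put(None)
worker.join()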
......
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')