Unverified Commit b6aad70a authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Fix] Fix the case where prompt_len = 0 (#1593)

parent 551a3a9d
......@@ -626,9 +626,11 @@ class Scheduler:
else:
logits_output = None
if self.tokenizer is not None:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
next_token_ids = torch.full(
(batch.batch_size(),), self.tokenizer.eos_token_id
)
else:
next_token_ids = [0] * len(batch.reqs)
next_token_ids = torch.full((batch.batch_size(),), 0)
return logits_output, next_token_ids
else: # embedding or reward model
assert batch.extend_num_tokens != 0
......
......@@ -526,7 +526,7 @@ class TokenizerManager:
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0)
await asyncio.sleep(0.001)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
result = await self.model_update_result
......
......@@ -624,6 +624,6 @@ def broadcast_pyobj(
tensor_data = torch.empty(size, dtype=torch.uint8)
dist.broadcast(tensor_data, src=0, group=dist_group)
serialized_data = bytes(tensor_data.tolist())
serialized_data = bytes(tensor_data.cpu().numpy())
data = pickle.loads(serialized_data)
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment