Unverified commit b88ea90d, authored by Liangsheng Yin and committed by GitHub
Browse files

Fix bugs of `logprobs_nums` (#1548)

parent 99ec439d
......@@ -748,6 +748,8 @@ class ScheduleBatch:
self.top_logprobs_nums = [
self.top_logprobs_nums[i] for i in unfinished_indices
]
else:
self.top_logprobs_nums = None
self.has_stream = any(req.stream for req in self.reqs)
self.sampling_info.filter_batch(unfinished_indices, new_indices)
......@@ -758,13 +760,11 @@ class ScheduleBatch:
# needs to be called with pre-merged Batch.reqs.
self.sampling_info.merge_batch(other.sampling_info)
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.out_cache_loc = None
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
elif self.return_logprob:
......@@ -772,6 +772,8 @@ class ScheduleBatch:
elif other.return_logprob:
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.has_stream = any(req.stream for req in self.reqs)
self.reqs.extend(other.reqs)
self.return_logprob = self.return_logprob or other.return_logprob
def get_model_worker_batch(self):
if self.forward_mode.is_decode():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment