Commit 25c73959 (unverified), authored by Lianmin Zheng, committed by GitHub
Browse files

Fix input logprob index (#9841)


Co-authored-by: Sheng Shen <sheng.s@berkeley.edu>
parent f05c6873
...@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin: ...@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
# This updates radix so others can match # This updates radix so others can match
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
if req.return_logprob: if batch.return_logprob:
assert extend_logprob_start_len_per_req is not None assert extend_logprob_start_len_per_req is not None
assert extend_input_len_per_req is not None assert extend_input_len_per_req is not None
extend_logprob_start_len = extend_logprob_start_len_per_req[i] extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i] extend_input_len = extend_input_len_per_req[i]
num_input_logprobs = extend_input_len - extend_logprob_start_len num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_logprob_return_values( if req.return_logprob:
i, self.add_logprob_return_values(
req, i,
logprob_pt, req,
next_token_ids, logprob_pt,
num_input_logprobs, next_token_ids,
logits_output, num_input_logprobs,
) logits_output,
)
logprob_pt += num_input_logprobs logprob_pt += num_input_logprobs
if ( if (
...@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin: ...@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
skip_stream_req = req skip_stream_req = req
# Incrementally update input logprobs. # Incrementally update input logprobs.
if req.return_logprob: if batch.return_logprob:
extend_logprob_start_len = extend_logprob_start_len_per_req[i] extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i] extend_input_len = extend_input_len_per_req[i]
if extend_logprob_start_len < extend_input_len: if extend_logprob_start_len < extend_input_len:
...@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin: ...@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
num_input_logprobs = ( num_input_logprobs = (
extend_input_len - extend_logprob_start_len extend_input_len - extend_logprob_start_len
) )
self.add_input_logprob_return_values( if req.return_logprob:
i, self.add_input_logprob_return_values(
req, i,
logits_output, req,
logprob_pt, logits_output,
num_input_logprobs, logprob_pt,
last_prefill_chunk=False, num_input_logprobs,
) last_prefill_chunk=False,
)
logprob_pt += num_input_logprobs logprob_pt += num_input_logprobs
self.set_next_batch_sampling_info_done(batch) self.set_next_batch_sampling_info_done(batch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment