Unverified Commit 4cb9aaed authored by Cody Yu, committed by GitHub
Browse files

Fix logprobs with logprob_start_len (#193)

parent 9de9a468
@@ -432,9 +432,14 @@ class ModelRpcServer(rpyc.Service):
req.logprob = logprobs[pt : pt + req.extend_input_len - 1] req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
req.normalized_logprob = normalized_logprobs[i] req.normalized_logprob = normalized_logprobs[i]
token_ids = req.input_ids + [next_token_ids[i]] # If logprob_start_len > 0, then first logprob_start_len prompt tokens
token_logprobs = [None] + req.logprob + [last_logprobs[i]] # will be ignored.
prompt_token_len = len(req.logprob)
token_ids = req.input_ids[-prompt_token_len :] + [next_token_ids[i]]
token_logprobs = req.logprob + [last_logprobs[i]]
req.token_logprob = list(zip(token_ids, token_logprobs)) req.token_logprob = list(zip(token_ids, token_logprobs))
if req.logprob_start_len == 0:
req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
pt += req.extend_input_len pt += req.extend_input_len
self.handle_finished_requests(batch) self.handle_finished_requests(batch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment