"tutorials/vscode:/vscode.git/clone" did not exist on "09a1a2f89354ec5e1c15fd8418aaa4affab53c2e"
Unverified Commit ceba0ce4 authored by strgrb's avatar strgrb Committed by GitHub
Browse files

support return logprobs for pipeline (#7356)


Co-authored-by: default avatarZhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
parent 1ab6be1b
......@@ -812,11 +812,28 @@ class Scheduler(
result.next_token_ids,
result.bid,
)
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
if self.cur_batch.return_logprob:
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
"extend_input_len_per_req": result.extend_input_len_per_req,
"extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
}
| (
{
f"logits_output.{k}": v
for k, v in result.logits_output.__dict__.items()
}
if result.logits_output is not None
else {}
)
)
else:
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
# send the output from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
......@@ -833,12 +850,25 @@ class Scheduler(
)
)
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
logits_output_args = {
k[len("logits_output.") :]: v
for k, v in next_pp_outputs.tensors.items()
if k.startswith("logits_output.")
}
if len(logits_output_args) > 0:
logits_output = LogitsProcessorOutput(**logits_output_args)
else:
logits_output = None
output_result = GenerationBatchResult(
logits_output=None,
logits_output=logits_output,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=None,
extend_logprob_start_len_per_req=None,
extend_input_len_per_req=next_pp_outputs.tensors.get(
"extend_input_len_per_req", None
),
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
"extend_logprob_start_len_per_req", None
),
bid=bids[next_mb_id],
can_run_cuda_graph=result.can_run_cuda_graph,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment