Unverified commit ceba0ce4 authored by strgrb, committed by GitHub
Browse files

support return logprobs for pipeline (#7356)


Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
parent 1ab6be1b
...@@ -812,11 +812,28 @@ class Scheduler( ...@@ -812,11 +812,28 @@ class Scheduler(
result.next_token_ids, result.next_token_ids,
result.bid, result.bid,
) )
pp_outputs = PPProxyTensors( if self.cur_batch.return_logprob:
{ pp_outputs = PPProxyTensors(
"next_token_ids": next_token_ids, {
} "next_token_ids": next_token_ids,
) "extend_input_len_per_req": result.extend_input_len_per_req,
"extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
}
| (
{
f"logits_output.{k}": v
for k, v in result.logits_output.__dict__.items()
}
if result.logits_output is not None
else {}
)
)
else:
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
# send the output from the last round to let the next stage worker run post processing # send the output from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict( self.pp_group.send_tensor_dict(
pp_outputs.tensors, pp_outputs.tensors,
...@@ -833,12 +850,25 @@ class Scheduler( ...@@ -833,12 +850,25 @@ class Scheduler(
) )
) )
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
logits_output_args = {
k[len("logits_output.") :]: v
for k, v in next_pp_outputs.tensors.items()
if k.startswith("logits_output.")
}
if len(logits_output_args) > 0:
logits_output = LogitsProcessorOutput(**logits_output_args)
else:
logits_output = None
output_result = GenerationBatchResult( output_result = GenerationBatchResult(
logits_output=None, logits_output=logits_output,
pp_hidden_states_proxy_tensors=None, pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"], next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=None, extend_input_len_per_req=next_pp_outputs.tensors.get(
extend_logprob_start_len_per_req=None, "extend_input_len_per_req", None
),
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
"extend_logprob_start_len_per_req", None
),
bid=bids[next_mb_id], bid=bids[next_mb_id],
can_run_cuda_graph=result.can_run_cuda_graph, can_run_cuda_graph=result.can_run_cuda_graph,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment