Unverified Commit 8ea0c275 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Minor code cleanup for _get_prompt_logprobs_dict (#23064)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 0fc8fa75
......@@ -1722,7 +1722,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output,
scheduler_output.num_scheduled_tokens,
)
# Get the valid generated tokens.
......@@ -2064,7 +2064,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
......@@ -2077,8 +2077,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens[req_id]
# Get metadata for this request.
request = self.requests[req_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment