Unverified Commit 66cc3fa5 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Model Runner V2] Multiple prompt logprobs support (#39937)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
Signed-off-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
parent 6d85b36a
......@@ -17,13 +17,14 @@ class PromptLogprobsWorker:
self.max_num_reqs = max_num_reqs
self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
self.num_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=np.int32)
# req_idx -> list of in-progress LogprobsTensors
self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}
def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
# For now, only support prompt logprobs for the prompt tokens (not top-k).
uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
self.num_prompt_logprobs[req_idx] = sampling_params.prompt_logprobs or 0
if uses_prompt_logprobs:
self.in_progress_prompt_logprobs[req_id] = []
......@@ -52,6 +53,7 @@ class PromptLogprobsWorker:
# Common case: No request asks for prompt logprobs.
return {}
num_prompt_logprobs = self.num_prompt_logprobs[idx_mapping_np]
prompt_lens = prompt_lens[idx_mapping_np]
# NOTE(woosuk): -1 because the last prompt token's hidden state is not
# needed for prompt logprobs.
......@@ -64,6 +66,14 @@ class PromptLogprobsWorker:
if not np.any(needs_prompt_logprobs):
return {}
# get the maximum number in this batch
requested_num_prompt_logprobs = num_prompt_logprobs[needs_prompt_logprobs]
max_num_prompt_logprobs = (
-1
if np.any(requested_num_prompt_logprobs == -1)
else int(requested_num_prompt_logprobs.max())
)
# Get the prompt logprobs token_ids.
prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
input_batch.num_tokens,
......@@ -72,45 +82,53 @@ class PromptLogprobsWorker:
num_computed_tokens,
all_token_ids,
)
# Compute the prompt logprobs.
prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
prompt_logprobs_token_ids,
hidden_states[: input_batch.num_tokens],
logits_fn,
prompt_token_ids, prompt_logprobs, prompt_ranks = (
compute_prompt_logprobs_with_chunking(
prompt_logprobs_token_ids,
hidden_states[: input_batch.num_tokens],
logits_fn,
max_num_prompt_logprobs,
)
)
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
is_prompt_chunked = pos_after_step < prompt_lens
query_start_loc_np = input_batch.query_start_loc_np
prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
continue
req_is_prompt_chunked = is_prompt_chunked[i]
start_idx = query_start_loc_np[i]
end_idx = query_start_loc_np[i + 1]
assert start_idx < end_idx, (
f"start_idx ({start_idx}) >= end_idx ({end_idx})"
)
if not is_prompt_chunked[i]:
if not req_is_prompt_chunked:
end_idx -= 1
logprobs = LogprobsTensors(
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
logprobs=prompt_logprobs[start_idx:end_idx],
selected_token_ranks=prompt_ranks[start_idx:end_idx],
# no logprobs if start_idx >= end_idx
logprobs = (
None
if start_idx >= end_idx
else LogprobsTensors(
logprob_token_ids=prompt_token_ids[start_idx:end_idx],
logprobs=prompt_logprobs[start_idx:end_idx],
selected_token_ranks=prompt_ranks[start_idx:end_idx],
)
)
prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
if is_prompt_chunked[i]:
# Prompt is chunked. Do not return the logprobs yet.
if logprobs is not None and (req_is_prompt_chunked or prompt_logprobs_list):
prompt_logprobs_list.append(logprobs)
if req_is_prompt_chunked:
# Prompt is chunked. Do not return the logprobs yet.
continue
if prompt_logprobs_list:
# Merge the in-progress logprobs.
prompt_logprobs_list.append(logprobs)
logprobs = LogprobsTensors(
logprob_token_ids=torch.cat(
[x.logprob_token_ids for x in prompt_logprobs_list]
......@@ -122,6 +140,9 @@ class PromptLogprobsWorker:
)
prompt_logprobs_list.clear()
if logprobs is None:
continue
prompt_logprobs_dict[req_id] = logprobs
return prompt_logprobs_dict
......@@ -184,10 +205,12 @@ def compute_prompt_logprobs_with_chunking(
prompt_token_ids: torch.Tensor,
prompt_hidden_states: torch.Tensor,
logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
num_prompt_logprobs: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Since materializing the full prompt logits can take too much memory,
# we compute it in chunks.
CHUNK_SIZE = 1024
token_ids = []
logprobs = []
ranks = []
prompt_token_ids = prompt_token_ids.to(torch.int64)
......@@ -195,14 +218,21 @@ def compute_prompt_logprobs_with_chunking(
end_idx = start_idx + CHUNK_SIZE
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
requested_num_prompt_logprobs = (
prompt_logits.shape[-1]
if num_prompt_logprobs == -1
else num_prompt_logprobs
)
prompt_logprobs = compute_topk_logprobs(
prompt_logits,
0, # num_logprobs
requested_num_prompt_logprobs,
prompt_token_ids[start_idx:end_idx],
)
token_ids.append(prompt_logprobs.logprob_token_ids)
logprobs.append(prompt_logprobs.logprobs)
ranks.append(prompt_logprobs.selected_token_ranks)
token_ids = torch.cat(token_ids, dim=0) if len(token_ids) > 1 else token_ids[0]
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
return logprobs, ranks
return token_ids, logprobs, ranks
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment