"vscode:/vscode.git/clone" did not exist on "99b0f687b48729084fd0462145b4e0c3e9fac383"
Unverified commit b997a18d authored by yichuan~, committed by GitHub

[Feat] Add support for optional start len of logprobs (#1035)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
parent d8627ed1
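
The diff below threads a new `logprob_start_len` option from the public request structs down to the logits processor, so prompt logprobs can start at an arbitrary token instead of token 0, and the new default of `-1` resolves to "last prompt token only". A rough usage sketch against the native `/generate` HTTP endpoint follows; the server URL/port, prompt, and sampling parameters are illustrative placeholders, not part of this commit:

```python
import requests

# Assumes a locally running SGLang server; URL and prompt are examples only.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
        "return_logprob": True,
        # New in this commit: index of the first prompt token whose logprob is returned.
        # The default is now -1, which resolves to len(input_ids) - 1 (last prompt token only).
        "logprob_start_len": 0,
        "top_logprobs_num": 5,
    },
)
print(resp.json())
```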
@@ -55,6 +55,9 @@ class LogitsMetadata:
     extend_start_loc: Optional[torch.Tensor] = None
     top_logprobs_nums: Optional[List[int]] = None
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None

     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
         return cls(
@@ -63,6 +66,8 @@ class LogitsMetadata:
             extend_start_loc=input_metadata.extend_start_loc,
             return_logprob=input_metadata.return_logprob,
             top_logprobs_nums=input_metadata.top_logprobs_nums,
+            extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
+            logprob_start_lens_cpu=input_metadata.logprob_start_lens_cpu,
         )
@@ -75,12 +80,16 @@ class LogitsProcessor(nn.Module):
         )

     def _get_normalized_prompt_logprobs(
-        self, input_token_logprobs, logits_metadata: LogitsMetadata
+        self,
+        input_token_logprobs: torch.Tensor,
+        cum_start_len0: torch.Tensor,
+        cum_start_len1: torch.Tensor,
+        logits_metadata: LogitsMetadata,
     ):
         logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)

-        start = logits_metadata.extend_start_loc.clone()
-        end = start + logits_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone() - cum_start_len0
+        end = start + logits_metadata.extend_seq_lens - 2 - cum_start_len1
         start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
@@ -93,7 +102,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs

     @staticmethod
-    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             output_top_logprobs = []
             max_k = max(logits_metadata.top_logprobs_nums)
@@ -107,7 +116,7 @@ class LogitsProcessor(nn.Module):
             # TODO: vectorize the code below
             input_top_logprobs, output_top_logprobs = [], []
             pt = 0
-            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu

             max_k = max(logits_metadata.top_logprobs_nums)
             ret = all_logprobs.topk(max_k, dim=1)
@@ -115,26 +124,30 @@ class LogitsProcessor(nn.Module):
             indices = ret.indices.tolist()

             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                start_len = logits_metadata.logprob_start_lens_cpu[i]
+                pruned_len = extend_seq_len - start_len
+
                 if extend_seq_len == 0:
                     input_top_logprobs.append([])
                     output_top_logprobs.append([])
                     continue
+
                 k = logits_metadata.top_logprobs_nums[i]
                 input_top_logprobs.append(
                     [
                         list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(extend_seq_len - 1)
+                        for j in range(pruned_len - 1)
                     ]
                 )
                 output_top_logprobs.append(
                     list(
                         zip(
-                            values[pt + extend_seq_len - 1][:k],
-                            indices[pt + extend_seq_len - 1][:k],
+                            values[pt + pruned_len - 1][:k],
+                            indices[pt + pruned_len - 1][:k],
                         )
                     )
                 )
-                pt += extend_seq_len
+                pt += pruned_len

         return input_top_logprobs, output_top_logprobs
@@ -205,7 +218,23 @@ class LogitsProcessor(nn.Module):
                     output_top_logprobs=output_top_logprobs,
                 )
             else:
-                all_logits = torch.matmul(hidden_states, weight.T)
+                pt, states, pruned_input_ids = 0, [], []
+                for i, extend_len in enumerate(logits_metadata.extend_seq_lens_cpu):
+                    start_len = logits_metadata.logprob_start_lens_cpu[i]
+                    states.append(hidden_states[pt + start_len : pt + extend_len])
+                    pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                    pt += extend_len
+
+                states = torch.cat(states, dim=0)
+                pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
+
+                cum_start_len1 = torch.tensor(
+                    logits_metadata.logprob_start_lens_cpu, device="cuda"
+                ).cumsum(0)
+                cum_start_len0 = torch.zeros_like(cum_start_len1)
+                cum_start_len0[1:] = cum_start_len1[:-1]
+
+                all_logits = torch.matmul(states, weight.T)
                 if self.do_tensor_parallel_all_gather:
                     all_logits = tensor_model_parallel_all_gather(all_logits)
                 all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -230,19 +259,25 @@ class LogitsProcessor(nn.Module):
                 else:
                     input_top_logprobs = output_top_logprobs = None

-                last_logprobs = all_logprobs[last_index]
+                last_logprobs = all_logprobs[last_index - cum_start_len1]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
                 input_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
-                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+                    torch.cat([pruned_input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]

                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    input_token_logprobs, logits_metadata
+                    input_token_logprobs,
+                    cum_start_len0,
+                    cum_start_len1,
+                    logits_metadata,
                 )

+                # Remove the last token logprob for the prefill tokens.
+                input_token_logprobs = input_token_logprobs[:-1]
+
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
...
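The core of the change above is that hidden states and input ids are pruned per request to the slice `[start_len, extend_len)` before the final vocabulary matmul, so logits and logprobs are only materialized for the requested suffix of each prompt. A minimal sketch of the slicing pattern with made-up batch sizes (values here are illustrative only, not taken from the code):

```python
import torch

# Two requests in one extend batch, hidden size 4 (hypothetical numbers).
extend_seq_lens_cpu = [5, 3]        # new prompt tokens per request
logprob_start_lens_cpu = [2, 0]     # keep logprobs from these offsets onward
hidden_states = torch.randn(sum(extend_seq_lens_cpu), 4)
input_ids = torch.arange(sum(extend_seq_lens_cpu))

pt, states, pruned_input_ids = 0, [], []
for i, extend_len in enumerate(extend_seq_lens_cpu):
    start_len = logprob_start_lens_cpu[i]
    # Keep only the tail of each request that logprobs were asked for.
    states.append(hidden_states[pt + start_len : pt + extend_len])
    pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
    pt += extend_len

states = torch.cat(states, dim=0)
pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
print(states.shape)       # torch.Size([6, 4]) -> (5 - 2) + (3 - 0) rows survive
print(pruned_input_ids)   # tensor([2, 3, 4, 5, 6, 7])
```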
@@ -75,7 +75,7 @@ class GenerateReqInput:
             if self.return_logprob is None:
                 self.return_logprob = False
             if self.logprob_start_len is None:
-                self.logprob_start_len = 0
+                self.logprob_start_len = -1
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
@@ -141,7 +141,7 @@ class GenerateReqInput:
                 self.return_logprob = [self.return_logprob] * num

             if self.logprob_start_len is None:
-                self.logprob_start_len = [0] * num
+                self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
...
@@ -195,6 +195,9 @@ class TokenizerManager:
                 if not_use_index
                 else obj.logprob_start_len[index]
             )
+            if return_logprob and logprob_start_len == -1:
+                logprob_start_len = len(input_ids) - 1
             top_logprobs_num = (
                 obj.top_logprobs_num
                 if not_use_index
@@ -245,6 +248,8 @@ class TokenizerManager:
             top_logprobs_num = obj.top_logprobs_num[0]

         if self.is_generation:
+            if return_logprob and logprob_start_len == -1:
+                logprob_start_len = len(input_ids) - 1
             tokenized_obj = TokenizedGenerateReqInput(
                 rid,
                 input_text,
@@ -334,6 +339,8 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params[index])

             if self.is_generation:
+                if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
+                    obj.logprob_start_len[index] = len(input_ids) - 1
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
                     obj.image_data[index]
                 )
...
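In the tokenizer manager, the sentinel `-1` is resolved at request time to "last prompt token only", which keeps prompt-logprob computation effectively disabled unless a caller asks for an earlier start. A standalone restatement of that rule (the helper name below is hypothetical, used only to isolate the check added above):

```python
def resolve_logprob_start_len(logprob_start_len: int, input_ids: list, return_logprob: bool) -> int:
    # -1 means "no prompt logprobs requested": start at the last prompt token.
    if return_logprob and logprob_start_len == -1:
        return len(input_ids) - 1
    return logprob_start_len

assert resolve_logprob_start_len(-1, [11, 22, 33, 44], True) == 3
assert resolve_logprob_start_len(0, [11, 22, 33, 44], True) == 0
```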
@@ -61,9 +61,11 @@ class InputMetadata:
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None

-    # Output options
+    # For logprob
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
+    extend_seq_lens_cpu: List[int] = None
+    logprob_start_lens_cpu: List[int] = None

     # For multimodal
     pixel_values: List[torch.Tensor] = None
@@ -139,6 +141,7 @@ class InputMetadata:
     def compute_extend_infos(self, batch: ScheduleBatch):
         if self.forward_mode == ForwardMode.DECODE:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+            self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:
             extend_lens_cpu = [
                 len(r.fill_ids) - batch.prefix_lens_cpu[i]
@@ -149,6 +152,19 @@ class InputMetadata:
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
             self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)

+            self.extend_seq_lens_cpu = extend_lens_cpu
+            self.logprob_start_lens_cpu = [
+                (
+                    min(
+                        req.logprob_start_len - batch.prefix_lens_cpu[i],
+                        extend_lens_cpu[i] - 1,
+                    )
+                    if req.logprob_start_len >= batch.prefix_lens_cpu[i]
+                    else extend_lens_cpu[i] - 1  # Fake extend, actually decode
+                )
+                for i, req in enumerate(batch.reqs)
+            ]

     @classmethod
     def from_schedule_batch(
         cls,
...
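The list comprehension above re-bases each request's start length onto its extend segment (the part of the prompt not covered by the prefix cache), caps it so at least one token remains, and falls back to `extend_len - 1` when the requested start lies inside the cached prefix. A small worked example with hypothetical numbers that mirrors only that arithmetic:

```python
# Hypothetical per-request values, not taken from the code.
prefix_lens_cpu = [4, 0, 6]      # tokens already covered by the prefix cache
extend_lens_cpu = [5, 7, 1]      # new tokens being extended in this batch
logprob_start_lens = [6, 3, 2]   # user-requested absolute start indices

logprob_start_lens_cpu = [
    (
        min(start - prefix, extend - 1)
        if start >= prefix
        else extend - 1  # start falls inside the cached prefix ("fake extend")
    )
    for start, prefix, extend in zip(logprob_start_lens, prefix_lens_cpu, extend_lens_cpu)
]
print(logprob_start_lens_cpu)  # [2, 3, 0]
```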
@@ -20,6 +20,7 @@
 import json
 import os
 import time
 import uuid
+import warnings
 from http import HTTPStatus
 from typing import Dict, List, Optional
@@ -383,20 +384,33 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")


-def v1_generate_request(all_requests):
+def v1_generate_request(all_requests: List[CompletionRequest]):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
-    first_prompt_type = type(all_requests[0].prompt)

+    # NOTE: with openai API, the prompt's logprobs are always not computed
+
+    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-        prompt = request.prompt
         assert (
-            type(prompt) == first_prompt_type
+            type(request.prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
-        prompts.append(prompt)
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Parallel sampling is not supported for completions from files"
+            )
+        if request.echo and request.logprobs:
+            warnings.warn(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use SGLang /request API."
+            )
+
+    for request in all_requests:
+        prompts.append(request.prompt)
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
@@ -416,14 +430,11 @@ def v1_generate_request(all_requests):
                 "ignore_eos": request.ignore_eos,
             }
         )
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )

     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
@@ -441,6 +452,7 @@ def v1_generate_request(all_requests):
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
         top_logprobs_num=top_logprobs_nums,
+        logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
@@ -694,12 +706,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response


-def v1_chat_generate_request(all_requests, tokenizer_manager):
+def v1_chat_generate_request(
+    all_requests: List[ChatCompletionRequest], tokenizer_manager
+):
     input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []

+    # NOTE: with openai API, the prompt's logprobs are always not computed
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
@@ -732,6 +750,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             image_data = None
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
@@ -758,17 +777,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
         else:
             prompt_kwargs = {"input_ids": input_ids}

     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
+        logprob_start_len=logprob_start_lens,
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
...
@@ -559,12 +559,14 @@ class Runtime:
         prompt: str,
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
     ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
         }
         response = requests.post(
...
@@ -209,6 +209,7 @@ class SRTRunner:
                 prompt,
                 sampling_params=sampling_params,
                 return_logprob=True,
+                logprob_start_len=0,
                 top_logprobs_num=NUM_TOP_LOGPROBS,
             )
             response = json.loads(response)
...
@@ -70,12 +70,11 @@ class TestOpenAIServer(unittest.TestCase):
         assert isinstance(response.choices[0].logprobs.tokens[0], str)
         assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
         ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

         # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
         # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
         assert ret_num_top_logprobs > 0

-        if echo:
-            assert response.choices[0].logprobs.token_logprobs[0] == None
-        else:
-            assert response.choices[0].logprobs.token_logprobs[0] != None
+        assert response.choices[0].logprobs.token_logprobs[0] != None

         assert response.id
...