Unverified commit 30db99b3, authored by Lianmin Zheng, committed by GitHub


Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
parent 0a409bd4
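For readers tracking the API surface rather than the code, the user-visible effect of this commit is that the logprob-related keys in the `/generate` response's `meta_info` change names, while `normalized_prompt_logprob` keeps its name. The mapping below is a summary sketch derived from the diff; the `migrate_meta_info` helper is hypothetical and not part of the repository.

```python
# Old meta_info key          -> new meta_info key (this commit)
RENAMED_META_INFO_KEYS = {
    "prefill_token_logprobs": "input_token_logprobs",
    "prefill_top_logprobs": "input_top_logprobs",
    "decode_token_logprobs": "output_token_logprobs",
    "decode_top_logprobs": "output_top_logprobs",
}


def migrate_meta_info(meta_info: dict) -> dict:
    """Hypothetical helper: translate old-style keys to the new names."""
    return {RENAMED_META_INFO_KEYS.get(key, key): value for key, value in meta_info.items()}
```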
...@@ -13,7 +13,7 @@ class GenerateReqInput: ...@@ -13,7 +13,7 @@ class GenerateReqInput:
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image. # See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None image_data: Optional[Union[List[str], str]] = None
# The sampling_params. # The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None sampling_params: Union[List[Dict], Dict] = None
# The request id. # The request id.
rid: Optional[Union[List[str], str]] = None rid: Optional[Union[List[str], str]] = None
...@@ -23,7 +23,7 @@ class GenerateReqInput: ...@@ -23,7 +23,7 @@ class GenerateReqInput:
logprob_start_len: Optional[Union[List[int], int]] = None logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return. # The number of top logprobs to return.
top_logprobs_num: Optional[Union[List[int], int]] = None top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in logprobs. # Whether to detokenize tokens into text in the returned logprobs.
return_text_in_logprobs: bool = False return_text_in_logprobs: bool = False
# Whether to stream output. # Whether to stream output.
stream: bool = False stream: bool = False
...@@ -32,27 +32,28 @@ class GenerateReqInput: ...@@ -32,27 +32,28 @@ class GenerateReqInput:
The `sampling_params` follows this format

Old version:

```python
class SamplingParams:
    def __init__(
        self,
        max_new_tokens: int = 16,
        stop: Optional[Union[str, List[str]]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        ignore_eos: bool = False,
        skip_special_tokens: bool = True,
        dtype: Optional[str] = None,
        regex: Optional[str] = None,
    ) -> None:
```

- `max_new_tokens`, `stop`, `temperature`, `top_p`, `top_k` are common sampling parameters.
- `ignore_eos` means ignoring the EOS token and continue decoding, which is helpful for benchmarking purposes.
- `regex` constrains the output to follow a given regular expression.

New version:

```python
# The maximum number of output tokens
max_new_tokens: int = 16,
# Stop when hitting any of the strings in this list.
stop: Optional[Union[str, List[str]]] = None,
# Sampling temperature
temperature: float = 1.0,
# Top-p sampling
top_p: float = 1.0,
# Top-k sampling
top_k: int = -1,
# Whether to ignore EOS token.
ignore_eos: bool = False,
# Whether to skip the special tokens during detokenization.
skip_special_tokens: bool = True,
# Whether to add spaces between special tokens during detokenization.
spaces_between_special_tokens: bool = True,
# Constrains the output to follow a given regular expression.
regex: Optional[str] = None,
# Do parallel sampling and return `n` outputs.
n: int = 1,
```
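As a usage illustration of the format documented above, here is a hedged sketch of passing `sampling_params` to the `/generate` HTTP endpoint. The host, port, and prompt are placeholder assumptions, and the printed key reflects the rename introduced in this commit.

```python
import requests

# Minimal sketch: assumes an SGLang server is already running locally.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "max_new_tokens": 32,
            "temperature": 0.7,
            "top_p": 0.95,
            "stop": ["\n"],
            "n": 1,
        },
        "return_logprob": True,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    },
)
meta_info = response.json()["meta_info"]
print(meta_info["output_token_logprobs"][:3])  # renamed from decode_token_logprobs
```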
## Examples ## Examples
### Normal ### Normal
......
...@@ -20,8 +20,8 @@ def main(): ...@@ -20,8 +20,8 @@ def main():
print("questions:", question) print("questions:", question)
print("choice:", state["tool"]) print("choice:", state["tool"])
meta_info = state.get_meta_info("tool") meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0]) print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1]) print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
print("-" * 50) print("-" * 50)
# Run a batch # Run a batch
...@@ -34,8 +34,8 @@ def main(): ...@@ -34,8 +34,8 @@ def main():
print("questions:", question) print("questions:", question)
print("choice:", state["tool"]) print("choice:", state["tool"])
meta_info = state.get_meta_info("tool") meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0]) print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1]) print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
print("-" * 50) print("-" * 50)
......
...@@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): ...@@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
top_logprobs_num=get_top_k, top_logprobs_num=get_top_k,
return_text_in_logprobs=True, return_text_in_logprobs=True,
) )
logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0] logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]
print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs)) print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
for idx, (f, token) in enumerate(zip(forks, logprobs)): for idx, (f, token) in enumerate(zip(forks, logprobs)):
...@@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): ...@@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
) )
# calculate probability disparity between the top and secondary tokens # calculate probability disparity between the top and secondary tokens
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]] x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]] x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]] tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
delta = (sum(x1s) - sum(x2s)) / len(x1s) delta = (sum(x1s) - sum(x2s)) / len(x1s)
# extract the answer span (without the '<|end_of_text|>' token) # extract the answer span (without the '<|end_of_text|>' token)
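The `delta` computed above is the average probability gap between the top-1 and top-2 candidates across the decode positions of a fork, which is the disparity score this example ranks forks by. A self-contained toy calculation with made-up logprobs:

```python
from math import exp

# Made-up top-1 / top-2 logprobs at three decode positions.
top1_logprobs = [-0.1, -0.3, -0.2]
top2_logprobs = [-2.0, -1.5, -2.3]

x1s = [exp(lp) for lp in top1_logprobs]
x2s = [exp(lp) for lp in top2_logprobs]
delta = (sum(x1s) - sum(x2s)) / len(x1s)
print(round(delta, 3))  # ~0.669: top-1 tokens are far more probable than the runners-up
```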
...@@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): ...@@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
answer_tokens = [ answer_tokens = [
xt[0][2] xt[0][2]
for xt in answer_forks[idx].get_meta_info("answer_span")[ for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs" "output_top_logprobs"
] ]
] ]
answer_x1s = [ answer_x1s = [
exp(xt[0][0]) exp(xt[0][0])
for xt in answer_forks[idx].get_meta_info("answer_span")[ for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs" "output_top_logprobs"
] ]
] ]
answer_x2s = [ answer_x2s = [
exp(xt[1][0]) exp(xt[1][0])
for xt in answer_forks[idx].get_meta_info("answer_span")[ for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs" "output_top_logprobs"
] ]
] ]
......
...@@ -56,14 +56,14 @@ def srt_api_request(name): ...@@ -56,14 +56,14 @@ def srt_api_request(name):
# fout.write(json.dumps(res, indent=4)) # fout.write(json.dumps(res, indent=4))
meta_info = res["meta_info"] meta_info = res["meta_info"]
assert len(meta_info["prefill_token_logprobs"]) == len( assert len(meta_info["input_token_logprobs"]) == len(
meta_info["prefill_top_logprobs"] meta_info["input_top_logprobs"]
) )
assert len(meta_info["decode_token_logprobs"]) == len( assert len(meta_info["output_token_logprobs"]) == len(
meta_info["decode_top_logprobs"] meta_info["output_top_logprobs"]
) )
assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"] assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1 assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
return res return res
...@@ -72,11 +72,11 @@ def pretty_print(res): ...@@ -72,11 +72,11 @@ def pretty_print(res):
meta_info = res["meta_info"] meta_info = res["meta_info"]
print("\n\n", "=" * 30, "Prefill", "=" * 30) print("\n\n", "=" * 30, "Prefill", "=" * 30)
for i in range(len(meta_info["prefill_token_logprobs"])): for i in range(len(meta_info["input_token_logprobs"])):
print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="") print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = ( top_ks = (
[str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]] [str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
if meta_info["prefill_top_logprobs"][i] if meta_info["input_top_logprobs"][i]
else [] else []
) )
for top_k in top_ks: for top_k in top_ks:
...@@ -84,9 +84,9 @@ def pretty_print(res): ...@@ -84,9 +84,9 @@ def pretty_print(res):
print() print()
print("\n\n", "=" * 30, "Decode", "=" * 30) print("\n\n", "=" * 30, "Decode", "=" * 30)
for i in range(len(meta_info["decode_token_logprobs"])): for i in range(len(meta_info["output_token_logprobs"])):
print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="") print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]] top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
for top_k in top_ks: for top_k in top_ks:
print(f"{top_k: <15}", end="") print(f"{top_k: <15}", end="")
print() print()
......
...@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend): ...@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend):
r["meta_info"]["normalized_prompt_logprob"] for r in obj r["meta_info"]["normalized_prompt_logprob"] for r in obj
] ]
decision = choices[np.argmax(normalized_prompt_logprobs)] decision = choices[np.argmax(normalized_prompt_logprobs)]
prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj] input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj] output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
return ( return (
decision, decision,
normalized_prompt_logprobs, normalized_prompt_logprobs,
prefill_token_logprobs, input_token_logprobs,
decode_token_logprobs, output_token_logprobs,
) )
def concatenate_and_append(self, src_rids: List[str], dst_rid: str): def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
......
...@@ -541,16 +541,16 @@ class StreamExecutor: ...@@ -541,16 +541,16 @@ class StreamExecutor:
( (
decision, decision,
normalized_prompt_logprobs, normalized_prompt_logprobs,
prefill_token_logprobs, input_token_logprobs,
decode_token_logprobs, output_token_logprobs,
) = self.backend.select(self, expr.choices, expr.temperature) ) = self.backend.select(self, expr.choices, expr.temperature)
if expr.name is not None: if expr.name is not None:
name = expr.name name = expr.name
self.variables[name] = decision self.variables[name] = decision
self.meta_info[name] = { self.meta_info[name] = {
"normalized_prompt_logprobs": normalized_prompt_logprobs, "normalized_prompt_logprobs": normalized_prompt_logprobs,
"prefill_token_logprobs": prefill_token_logprobs, "input_token_logprobs": input_token_logprobs,
"decode_token_logprobs": decode_token_logprobs, "output_token_logprobs": output_token_logprobs,
} }
self.variable_event[name].set() self.variable_event[name].set()
self.text_ += decision self.text_ += decision
......
...@@ -22,13 +22,13 @@ class LogitProcessorOutput: ...@@ -22,13 +22,13 @@ class LogitProcessorOutput:
# The normalized logprobs of prompts. shape: [#seq] # The normalized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor normalized_prompt_logprobs: torch.Tensor
# The logprobs of prefill tokens. shape: [#token, vocab_size] # The logprobs of input tokens. shape: [#token, vocab_size]
prefill_token_logprobs: torch.Tensor input_token_logprobs: torch.Tensor
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id) # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
prefill_top_logprobs: List input_top_logprobs: List
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id) # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
decode_top_logprobs: List output_top_logprobs: List
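To make the shape comments above concrete, here is a tiny self-contained sketch of the nesting in the top-logprob fields; the numbers and variable names are made up for illustration.

```python
# One sequence, two input positions, top-2 candidates per position;
# each innermost entry is a (logprob, token_id) tuple as described above.
input_top_logprobs = [
    [[(-0.1, 42), (-2.3, 7)], [(-0.5, 13), (-1.9, 99)]],
]

for seq_idx, per_position in enumerate(input_top_logprobs):
    for pos, topk in enumerate(per_position):
        best_logprob, best_token_id = topk[0]
        print(seq_idx, pos, best_token_id, best_logprob)
```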
@dataclasses.dataclass @dataclasses.dataclass
...@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module): ...@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def _get_normalized_prompt_logprobs( def _get_normalized_prompt_logprobs(
self, prefill_token_logprobs, logits_metadata: LogitsMetadata self, input_token_logprobs, logits_metadata: LogitsMetadata
): ):
logprobs_cumsum = torch.cumsum( logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
prefill_token_logprobs, dim=0, dtype=torch.float32
)
start = logits_metadata.extend_start_loc.clone() start = logits_metadata.extend_start_loc.clone()
end = start + logits_metadata.extend_seq_lens - 2 end = start + logits_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
sum_logp = ( sum_logp = (
logprobs_cumsum[end] logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
- logprobs_cumsum[start]
+ prefill_token_logprobs[start]
) )
normalized_prompt_logprobs = sum_logp / ( normalized_prompt_logprobs = sum_logp / (
(logits_metadata.extend_seq_lens - 1).clamp(min=1) (logits_metadata.extend_seq_lens - 1).clamp(min=1)
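The cumulative-sum bookkeeping above amounts to a per-sequence mean of the prompt-token logprobs, where the first prompt token contributes nothing because it has no preceding context (hence the `- 1` and `- 2` offsets). A minimal sketch of the same quantity for a single sequence, assuming `input_token_logprobs[j]` holds the logprob of prompt token `j + 1`:

```python
import torch

# Hypothetical per-token logprobs for a 5-token prompt: entry j is the logprob
# of token j + 1 given the prefix, so there are 5 - 1 = 4 useful entries.
input_token_logprobs = torch.tensor([-1.2, -0.3, -2.0, -0.7])

# For one sequence, the cumsum/gather trick reduces to a simple mean.
normalized_prompt_logprob = input_token_logprobs.mean()
print(normalized_prompt_logprob)  # tensor(-1.0500)
```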
...@@ -83,34 +79,34 @@ class LogitsProcessor(nn.Module): ...@@ -83,34 +79,34 @@ class LogitsProcessor(nn.Module):
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below # TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = [] output_top_logprobs = []
for i in range(all_logprobs.shape[0]): for i in range(all_logprobs.shape[0]):
k = logits_metadata.top_logprobs_nums[i] k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k) t = all_logprobs[i].topk(k)
v_cpu = t.values.tolist() v_cpu = t.values.tolist()
p_cpu = t.indices.tolist() p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu))) output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs return None, output_top_logprobs
else: else:
prefill_top_logprobs, decode_top_logprobs = [], [] input_top_logprobs, output_top_logprobs = [], []
pt = 0 pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist() extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu): for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0: if extend_seq_len == 0:
prefill_top_logprobs.append([]) input_top_logprobs.append([])
decode_top_logprobs.append([]) output_top_logprobs.append([])
continue continue
k = logits_metadata.top_logprobs_nums[i] k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k) t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist() vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist() ps_cpu = t.indices.tolist()
prefill_top_logprobs.append( input_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)] [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
) )
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1]))) output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_len pt += extend_seq_len
return prefill_top_logprobs, decode_top_logprobs return input_top_logprobs, output_top_logprobs
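In the extend (prefill) branch above, the last position of each sequence's span feeds the output-side top-k while all earlier positions feed the input side. A toy, self-contained sketch of that split, independent of the surrounding class:

```python
import torch

# Toy logprob matrix: one sequence extended by 3 tokens over a vocabulary of 5.
all_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
k = 2

t = all_logprobs.topk(k)  # top-k candidates at every position
vs, ps = t.values.tolist(), t.indices.tolist()

# First len-1 positions -> input-side top-k; last position -> output-side top-k.
input_top_logprobs = [list(zip(vs[j], ps[j])) for j in range(len(vs) - 1)]
output_top_logprobs = list(zip(vs[-1], ps[-1]))
print(input_top_logprobs)
print(output_top_logprobs)
```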
def forward( def forward(
self, self,
...@@ -150,9 +146,9 @@ class LogitsProcessor(nn.Module): ...@@ -150,9 +146,9 @@ class LogitsProcessor(nn.Module):
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=None, next_token_logprobs=None,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
prefill_token_logprobs=None, input_token_logprobs=None,
prefill_top_logprobs=None, input_top_logprobs=None,
decode_top_logprobs=None, output_top_logprobs=None,
) )
else: else:
# When logprob is requested, compute the logits for all tokens. # When logprob is requested, compute the logits for all tokens.
...@@ -164,19 +160,19 @@ class LogitsProcessor(nn.Module): ...@@ -164,19 +160,19 @@ class LogitsProcessor(nn.Module):
x > 0 for x in logits_metadata.top_logprobs_nums x > 0 for x in logits_metadata.top_logprobs_nums
) )
if return_top_logprob: if return_top_logprob:
decode_top_logprobs = self.get_top_logprobs( output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata last_logprobs, logits_metadata
)[1] )[1]
else: else:
decode_top_logprobs = None output_top_logprobs = None
return LogitProcessorOutput( return LogitProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=last_logprobs, next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
prefill_token_logprobs=None, input_token_logprobs=None,
prefill_top_logprobs=None, input_top_logprobs=None,
decode_top_logprobs=decode_top_logprobs, output_top_logprobs=output_top_logprobs,
) )
else: else:
all_logits = torch.matmul(hidden_states, weight.T) all_logits = torch.matmul(hidden_states, weight.T)
...@@ -193,32 +189,32 @@ class LogitsProcessor(nn.Module): ...@@ -193,32 +189,32 @@ class LogitsProcessor(nn.Module):
x > 0 for x in logits_metadata.top_logprobs_nums x > 0 for x in logits_metadata.top_logprobs_nums
) )
if return_top_logprob: if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs( input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata all_logprobs, logits_metadata
) )
else: else:
prefill_top_logprobs = decode_top_logprobs = None input_top_logprobs = output_top_logprobs = None
last_logprobs = all_logprobs[last_index] last_logprobs = all_logprobs[last_index]
# Compute the logprobs and normalized logprobs for the prefill tokens. # Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation. # Note that we pad a zero at the end of each sequence for easy computation.
prefill_token_logprobs = all_logprobs[ input_token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"), torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
] ]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, logits_metadata input_token_logprobs, logits_metadata
) )
return LogitProcessorOutput( return LogitProcessorOutput(
next_token_logits=last_logits, next_token_logits=last_logits,
next_token_logprobs=last_logprobs, next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs,
prefill_token_logprobs=prefill_token_logprobs, input_token_logprobs=input_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs, input_top_logprobs=input_top_logprobs,
decode_top_logprobs=decode_top_logprobs, output_top_logprobs=output_top_logprobs,
) )
......
...@@ -226,9 +226,9 @@ class CudaGraphRunner: ...@@ -226,9 +226,9 @@ class CudaGraphRunner:
next_token_logits=output.next_token_logits[:raw_bs], next_token_logits=output.next_token_logits[:raw_bs],
next_token_logprobs=None, next_token_logprobs=None,
normalized_prompt_logprobs=None, normalized_prompt_logprobs=None,
prefill_token_logprobs=None, input_token_logprobs=None,
prefill_top_logprobs=None, input_top_logprobs=None,
decode_top_logprobs=None, output_top_logprobs=None,
) )
# Extract logprobs # Extract logprobs
...@@ -242,7 +242,7 @@ class CudaGraphRunner: ...@@ -242,7 +242,7 @@ class CudaGraphRunner:
forward_mode=ForwardMode.DECODE, forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=batch.top_logprobs_nums,
) )
output.decode_top_logprobs = LogitsProcessor.get_top_logprobs( output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
output.next_token_logprobs, logits_metadata output.next_token_logprobs, logits_metadata
)[1] )[1]
......
...@@ -124,10 +124,10 @@ class Req: ...@@ -124,10 +124,10 @@ class Req:
self.logprob_start_len = 0 self.logprob_start_len = 0
self.top_logprobs_num = 0 self.top_logprobs_num = 0
self.normalized_prompt_logprob = None self.normalized_prompt_logprob = None
self.prefill_token_logprobs = None self.input_token_logprobs = None
self.prefill_top_logprobs = None self.input_top_logprobs = None
self.decode_token_logprobs = [] self.output_token_logprobs = []
self.decode_top_logprobs = [] self.output_top_logprobs = []
# The tokens are prefilled but need to be considered as decode tokens # The tokens are prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs # and should be updated for the decode logprobs
self.last_update_decode_tokens = 0 self.last_update_decode_tokens = 0
...@@ -244,8 +244,8 @@ class Req: ...@@ -244,8 +244,8 @@ class Req:
k = k + 1 k = k + 1
else: else:
break break
self.decode_token_logprobs = self.decode_token_logprobs[:k] self.output_token_logprobs = self.output_token_logprobs[:k]
self.decode_top_logprobs = self.decode_top_logprobs[:k] self.output_top_logprobs = self.output_top_logprobs[:k]
self.logprob_start_len = prompt_tokens + k self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k self.last_update_decode_tokens = len(self.output_ids) - k
......
...@@ -455,7 +455,7 @@ class ModelTpServer: ...@@ -455,7 +455,7 @@ class ModelTpServer:
torch.arange(len(next_token_ids), device=next_token_ids.device), torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids, next_token_ids,
].tolist() ].tolist()
output.prefill_token_logprobs = output.prefill_token_logprobs.tolist() output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = ( output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist() output.normalized_prompt_logprobs.tolist()
) )
...@@ -481,24 +481,24 @@ class ModelTpServer: ...@@ -481,24 +481,24 @@ class ModelTpServer:
if req.normalized_prompt_logprob is None: if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
if req.prefill_token_logprobs is None: if req.input_token_logprobs is None:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list( req.input_token_logprobs = list(
zip( zip(
output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1], output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
req.input_ids[-req.extend_input_len + 1 :], req.input_ids[-req.extend_input_len + 1 :],
) )
) )
if req.logprob_start_len == 0: if req.logprob_start_len == 0:
req.prefill_token_logprobs = [ req.input_token_logprobs = [
(None, req.input_ids[0]) (None, req.input_ids[0])
] + req.prefill_token_logprobs ] + req.input_token_logprobs
if req.last_update_decode_tokens != 0: if req.last_update_decode_tokens != 0:
req.decode_token_logprobs.extend( req.output_token_logprobs.extend(
list( list(
zip( zip(
output.prefill_token_logprobs[ output.input_token_logprobs[
pt pt
+ req.extend_input_len + req.extend_input_len
- req.last_update_decode_tokens : pt - req.last_update_decode_tokens : pt
...@@ -510,21 +510,21 @@ class ModelTpServer: ...@@ -510,21 +510,21 @@ class ModelTpServer:
) )
) )
req.decode_token_logprobs.append( req.output_token_logprobs.append(
(output.next_token_logprobs[i], next_token_ids[i]) (output.next_token_logprobs[i], next_token_ids[i])
) )
if req.top_logprobs_num > 0: if req.top_logprobs_num > 0:
if req.prefill_top_logprobs is None: if req.input_top_logprobs is None:
req.prefill_top_logprobs = output.prefill_top_logprobs[i] req.input_top_logprobs = output.input_top_logprobs[i]
if req.logprob_start_len == 0: if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs req.input_top_logprobs = [None] + req.input_top_logprobs
if req.last_update_decode_tokens != 0: if req.last_update_decode_tokens != 0:
req.decode_top_logprobs.extend( req.output_top_logprobs.extend(
output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :] output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
) )
req.decode_top_logprobs.append(output.decode_top_logprobs[i]) req.output_top_logprobs.append(output.output_top_logprobs[i])
def cache_filled_batch(self, batch: Batch): def cache_filled_batch(self, batch: Batch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
...@@ -589,11 +589,11 @@ class ModelTpServer: ...@@ -589,11 +589,11 @@ class ModelTpServer:
req.check_finished() req.check_finished()
if req.return_logprob: if req.return_logprob:
req.decode_token_logprobs.append( req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id) (next_token_logprobs[i], next_token_id)
) )
if req.top_logprobs_num > 0: if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(output.decode_top_logprobs[i]) req.output_top_logprobs.append(output.output_top_logprobs[i])
self.handle_finished_requests(batch) self.handle_finished_requests(batch)
...@@ -645,16 +645,16 @@ class ModelTpServer: ...@@ -645,16 +645,16 @@ class ModelTpServer:
} }
if req.return_logprob: if req.return_logprob:
( (
meta_info["prefill_token_logprobs"], meta_info["input_token_logprobs"],
meta_info["decode_token_logprobs"], meta_info["output_token_logprobs"],
meta_info["prefill_top_logprobs"], meta_info["input_top_logprobs"],
meta_info["decode_top_logprobs"], meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"], meta_info["normalized_prompt_logprob"],
) = ( ) = (
req.prefill_token_logprobs, req.input_token_logprobs,
req.decode_token_logprobs, req.output_token_logprobs,
req.prefill_top_logprobs, req.input_top_logprobs,
req.decode_top_logprobs, req.output_top_logprobs,
req.normalized_prompt_logprob, req.normalized_prompt_logprob,
) )
output_meta_info.append(meta_info) output_meta_info.append(meta_info)
......
...@@ -20,7 +20,7 @@ class GenerateReqInput: ...@@ -20,7 +20,7 @@ class GenerateReqInput:
# The image input. It can be a file name, a url, or base64 encoded string. # The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image. # See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None image_data: Optional[Union[List[str], str]] = None
# The sampling_params. # The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None sampling_params: Union[List[Dict], Dict] = None
# The request id. # The request id.
rid: Optional[Union[List[str], str]] = None rid: Optional[Union[List[str], str]] = None
...@@ -30,7 +30,7 @@ class GenerateReqInput: ...@@ -30,7 +30,7 @@ class GenerateReqInput:
logprob_start_len: Optional[Union[List[int], int]] = None logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return. # The number of top logprobs to return.
top_logprobs_num: Optional[Union[List[int], int]] = None top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in logprobs. # Whether to detokenize tokens into text in the returned logprobs.
return_text_in_logprobs: bool = False return_text_in_logprobs: bool = False
# Whether to stream output. # Whether to stream output.
stream: bool = False stream: bool = False
......
...@@ -448,23 +448,23 @@ class TokenizerManager: ...@@ -448,23 +448,23 @@ class TokenizerManager:
return_text_in_logprobs: bool, return_text_in_logprobs: bool,
): ):
if return_logprob: if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens( ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
) )
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens( ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
) )
if top_logprobs_num > 0: if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = ( ret["meta_info"]["input_top_logprobs"] = (
self.detokenize_top_logprobs_tokens( self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], ret["meta_info"]["input_top_logprobs"],
return_text_in_logprobs, return_text_in_logprobs,
) )
) )
ret["meta_info"]["decode_top_logprobs"] = ( ret["meta_info"]["output_top_logprobs"] = (
self.detokenize_top_logprobs_tokens( self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
) )
) )
return ret return ret
......
...@@ -54,9 +54,9 @@ class LlamaForClassification(nn.Module): ...@@ -54,9 +54,9 @@ class LlamaForClassification(nn.Module):
next_token_logits=scores, next_token_logits=scores,
next_token_logprobs=scores, next_token_logprobs=scores,
normalized_prompt_logprobs=scores, normalized_prompt_logprobs=scores,
prefill_token_logprobs=torch.ones_like(input_ids), input_token_logprobs=torch.ones_like(input_ids),
prefill_top_logprobs=None, input_top_logprobs=None,
decode_top_logprobs=None, output_top_logprobs=None,
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
...@@ -140,29 +140,29 @@ async def v1_completions(tokenizer_manager, raw_request: Request): ...@@ -140,29 +140,29 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if request.logprobs: if request.logprobs:
# The first chunk and echo is enabled. # The first chunk and echo is enabled.
if not stream_buffer and request.echo: if not stream_buffer and request.echo:
prefill_token_logprobs = content["meta_info"][ input_token_logprobs = content["meta_info"][
"prefill_token_logprobs" "input_token_logprobs"
] ]
prefill_top_logprobs = content["meta_info"][ input_top_logprobs = content["meta_info"][
"prefill_top_logprobs" "input_top_logprobs"
] ]
else: else:
prefill_token_logprobs = None input_token_logprobs = None
prefill_top_logprobs = None input_top_logprobs = None
logprobs = to_openai_style_logprobs( logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs, input_token_logprobs=input_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs, input_top_logprobs=input_top_logprobs,
decode_token_logprobs=content["meta_info"][ output_token_logprobs=content["meta_info"][
"decode_token_logprobs" "output_token_logprobs"
][n_prev_token:], ][n_prev_token:],
decode_top_logprobs=content["meta_info"][ output_top_logprobs=content["meta_info"][
"decode_top_logprobs" "output_top_logprobs"
][n_prev_token:], ][n_prev_token:],
) )
n_prev_token = len( n_prev_token = len(
content["meta_info"]["decode_token_logprobs"] content["meta_info"]["output_token_logprobs"]
) )
else: else:
logprobs = None logprobs = None
...@@ -218,17 +218,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request): ...@@ -218,17 +218,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if request.logprobs: if request.logprobs:
if request.echo: if request.echo:
prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"] input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"] input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
else: else:
prefill_token_logprobs = None input_token_logprobs = None
prefill_top_logprobs = None input_top_logprobs = None
logprobs = to_openai_style_logprobs( logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs, input_token_logprobs=input_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs, input_top_logprobs=input_top_logprobs,
decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"], output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"], output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
) )
else: else:
logprobs = None logprobs = None
...@@ -401,10 +401,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): ...@@ -401,10 +401,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
def to_openai_style_logprobs( def to_openai_style_logprobs(
prefill_token_logprobs=None, input_token_logprobs=None,
decode_token_logprobs=None, output_token_logprobs=None,
prefill_top_logprobs=None, input_top_logprobs=None,
decode_top_logprobs=None, output_top_logprobs=None,
): ):
ret_logprobs = LogProbs() ret_logprobs = LogProbs()
...@@ -425,13 +425,13 @@ def to_openai_style_logprobs( ...@@ -425,13 +425,13 @@ def to_openai_style_logprobs(
else: else:
ret_logprobs.top_logprobs.append(None) ret_logprobs.top_logprobs.append(None)
if prefill_token_logprobs is not None: if input_token_logprobs is not None:
append_token_logprobs(prefill_token_logprobs) append_token_logprobs(input_token_logprobs)
if decode_token_logprobs is not None: if output_token_logprobs is not None:
append_token_logprobs(decode_token_logprobs) append_token_logprobs(output_token_logprobs)
if prefill_top_logprobs is not None: if input_top_logprobs is not None:
append_top_logprobs(prefill_top_logprobs) append_top_logprobs(input_top_logprobs)
if decode_top_logprobs is not None: if output_top_logprobs is not None:
append_top_logprobs(decode_top_logprobs) append_top_logprobs(output_top_logprobs)
return ret_logprobs return ret_logprobs
...@@ -13,14 +13,15 @@ import json ...@@ -13,14 +13,15 @@ import json
import requests import requests
def test_decode(url, return_logprob, top_logprobs_num, return_text): def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
response = requests.post( response = requests.post(
url + "/generate", url + "/generate",
json={ json={
"text": "The capital of France is", "text": "The capital of France is",
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 32, "max_new_tokens": 32,
"n": n,
}, },
"stream": False, "stream": False,
"return_logprob": return_logprob, "return_logprob": return_logprob,
...@@ -41,8 +42,14 @@ if __name__ == "__main__": ...@@ -41,8 +42,14 @@ if __name__ == "__main__":
url = f"{args.host}:{args.port}" url = f"{args.host}:{args.port}"
Old version:

test_decode(url, False, 0, False)
test_decode(url, True, 0, False)
test_decode(url, True, 0, True)
test_decode(url, True, 3, False)
test_decode(url, True, 3, True)

New version:

test_decode(url)
test_decode(url, n=3)

for top_logprobs_num in [0, 3]:
    for return_text in [True, False]:
        test_decode(
            url,
            return_logprob=True,
            top_logprobs_num=top_logprobs_num,
            return_text=return_text,
        )
...@@ -40,14 +40,14 @@ def test_decode_stream(url, return_logprob, top_logprobs_num): ...@@ -40,14 +40,14 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
data = json.loads(chunk[5:].strip("\n")) data = json.loads(chunk[5:].strip("\n"))
if return_logprob: if return_logprob:
assert data["meta_info"]["prefill_token_logprobs"] is not None assert data["meta_info"]["input_token_logprobs"] is not None
assert data["meta_info"]["decode_token_logprobs"] is not None assert data["meta_info"]["output_token_logprobs"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None assert data["meta_info"]["normalized_prompt_logprob"] is not None
for logprob, token_id, token_text in data["meta_info"][ for logprob, token_id, token_text in data["meta_info"][
"decode_token_logprobs" "output_token_logprobs"
][prev:]: ][prev:]:
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True) print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
prev = len(data["meta_info"]["decode_token_logprobs"]) prev = len(data["meta_info"]["output_token_logprobs"])
else: else:
output = data["text"].strip() output = data["text"].strip()
print(output[prev:], end="", flush=True) print(output[prev:], end="", flush=True)
......