Unverified Commit 30db99b3 authored by Lianmin Zheng, committed by GitHub

Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
parent 0a409bd4
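For API clients, this change is only a key rename in the returned `meta_info`; the request fields are unchanged. A minimal migration sketch, assuming a locally running server (the URL is a placeholder; the request and response keys are the ones that appear in the diff below):

```python
import requests

# Placeholder URL for a locally running server.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,
        "top_logprobs_num": 2,
    },
).json()

meta_info = resp["meta_info"]
# Old keys: prefill_token_logprobs, decode_token_logprobs,
#           prefill_top_logprobs,   decode_top_logprobs
# New keys after this commit:
input_token_logprobs = meta_info["input_token_logprobs"]
output_token_logprobs = meta_info["output_token_logprobs"]
input_top_logprobs = meta_info["input_top_logprobs"]
output_top_logprobs = meta_info["output_top_logprobs"]
```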
......@@ -13,7 +13,7 @@ class GenerateReqInput:
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The sampling_params.
# The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
......@@ -23,7 +23,7 @@ class GenerateReqInput:
logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return.
top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in logprobs.
# Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False
# Whether to stream output.
stream: bool = False
......@@ -32,27 +32,28 @@ class GenerateReqInput:
The `sampling_params` follows this format
```python
class SamplingParams:
def __init__(
self,
max_new_tokens: int = 16,
stop: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
dtype: Optional[str] = None,
regex: Optional[str] = None,
) -> None:
# The maximum number of output tokens
max_new_tokens: int = 16,
# Stop when hitting any of the strings in this list.
stop: Optional[Union[str, List[str]]] = None,
# Sampling temperature
temperature: float = 1.0,
# Top-p sampling
top_p: float = 1.0,
# Top-k sampling
top_k: int = -1,
# Whether to ignore EOS token.
ignore_eos: bool = False,
# Whether to skip the special tokens during detokenization.
skip_special_tokens: bool = True,
# Whether to add spaces between special tokens during detokenization.
spaces_between_special_tokens: bool = True,
# Constrains the output to follow a given regular expression.
regex: Optional[str] = None,
# Do parallel sampling and return `n` outputs.
n: int = 1,
```
- `max_new_tokens`, `stop`, `temperature`, `top_p`, `top_k` are common sampling parameters.
- `ignore_eos` means ignoring the EOS token and continuing decoding, which is helpful for benchmarking purposes.
- `regex` constrains the output to follow a given regular expression.
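As a concrete illustration of the format above, here is a small sketch of a `/generate` request that sets several of these sampling parameters, including the new `n` for parallel sampling (the URL is a placeholder for a locally running server; the exact response shape is not shown):

```python
import requests

# Placeholder URL for a locally running server.
payload = {
    "text": "Write a one-sentence summary of the solar system.",
    "sampling_params": {
        "max_new_tokens": 64,
        "temperature": 0.7,
        "top_p": 0.9,
        "stop": ["\n\n"],
        "n": 2,  # parallel sampling: ask for 2 outputs
    },
}
response = requests.post("http://localhost:30000/generate", json=payload)
print(response.json())
```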
## Examples
### Normal
......
......@@ -20,8 +20,8 @@ def main():
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
print("-" * 50)
# Run a batch
......@@ -34,8 +34,8 @@ def main():
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
print("-" * 50)
......
......@@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
top_logprobs_num=get_top_k,
return_text_in_logprobs=True,
)
logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]
logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]
print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
for idx, (f, token) in enumerate(zip(forks, logprobs)):
......@@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
)
# calculate probability disparity between the top and secondary tokens
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
delta = (sum(x1s) - sum(x2s)) / len(x1s)
# extract the answer span (without the '<|end_of_text|>' token)
......@@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
answer_tokens = [
xt[0][2]
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]
answer_x1s = [
exp(xt[0][0])
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]
answer_x2s = [
exp(xt[1][0])
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]
......
......@@ -56,14 +56,14 @@ def srt_api_request(name):
# fout.write(json.dumps(res, indent=4))
meta_info = res["meta_info"]
assert len(meta_info["prefill_token_logprobs"]) == len(
meta_info["prefill_top_logprobs"]
assert len(meta_info["input_token_logprobs"]) == len(
meta_info["input_top_logprobs"]
)
assert len(meta_info["decode_token_logprobs"]) == len(
meta_info["decode_top_logprobs"]
assert len(meta_info["output_token_logprobs"]) == len(
meta_info["output_top_logprobs"]
)
assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
return res
......@@ -72,11 +72,11 @@ def pretty_print(res):
meta_info = res["meta_info"]
print("\n\n", "=" * 30, "Prefill", "=" * 30)
for i in range(len(meta_info["prefill_token_logprobs"])):
print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
for i in range(len(meta_info["input_token_logprobs"])):
print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = (
[str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
if meta_info["prefill_top_logprobs"][i]
[str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
if meta_info["input_top_logprobs"][i]
else []
)
for top_k in top_ks:
......@@ -84,9 +84,9 @@ def pretty_print(res):
print()
print("\n\n", "=" * 30, "Decode", "=" * 30)
for i in range(len(meta_info["decode_token_logprobs"])):
print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
for i in range(len(meta_info["output_token_logprobs"])):
print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
for top_k in top_ks:
print(f"{top_k: <15}", end="")
print()
......
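For readers of the two scripts above, each entry in the renamed lists is a `(logprob, token_id)` pair, with a third `token_text` element when `return_text_in_logprobs=True`; that is what the `[2]` indexing here and the tuple unpacking in the streaming test at the end of this diff rely on. A rough, made-up sketch of a response's `meta_info` under the new names (all values are illustrative):

```python
# Illustrative, made-up values only; the layout matches what the scripts above
# index into (the third tuple element is present when return_text_in_logprobs=True).
meta_info = {
    # One (logprob, token_id, token_text) tuple per prompt token;
    # the first prompt token carries no logprob.
    "input_token_logprobs": [
        (None, 1, "<s>"),
        (-6.1, 450, "The"),
        (-2.3, 7483, " capital"),
    ],
    # One tuple per generated token.
    "output_token_logprobs": [
        (-0.4, 3681, " Paris"),
        (-0.1, 29889, "."),
    ],
    # For each position, a list of top-k tuples (empty or None where unavailable,
    # hence the truthiness check in pretty_print above).
    "input_top_logprobs": [
        None,
        [(-1.9, 910, "This"), (-6.1, 450, "The")],
        [(-2.3, 7483, " capital"), (-3.0, 1556, " most")],
    ],
    "output_top_logprobs": [
        [(-0.4, 3681, " Paris"), (-2.9, 5798, " located")],
        [(-0.1, 29889, "."), (-3.5, 11, ",")],
    ],
}
```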
......@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend):
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
decision = choices[np.argmax(normalized_prompt_logprobs)]
prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
return (
decision,
normalized_prompt_logprobs,
prefill_token_logprobs,
decode_token_logprobs,
input_token_logprobs,
output_token_logprobs,
)
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
......
......@@ -541,16 +541,16 @@ class StreamExecutor:
(
decision,
normalized_prompt_logprobs,
prefill_token_logprobs,
decode_token_logprobs,
input_token_logprobs,
output_token_logprobs,
) = self.backend.select(self, expr.choices, expr.temperature)
if expr.name is not None:
name = expr.name
self.variables[name] = decision
self.meta_info[name] = {
"normalized_prompt_logprobs": normalized_prompt_logprobs,
"prefill_token_logprobs": prefill_token_logprobs,
"decode_token_logprobs": decode_token_logprobs,
"input_token_logprobs": input_token_logprobs,
"output_token_logprobs": output_token_logprobs,
}
self.variable_event[name].set()
self.text_ += decision
......
......@@ -22,13 +22,13 @@ class LogitProcessorOutput:
# The normalized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor
# The logprobs of prefill tokens. shape: [#token, vocab_size]
prefill_token_logprobs: torch.Tensor
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
prefill_top_logprobs: List
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
decode_top_logprobs: List
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
input_top_logprobs: List
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
output_top_logprobs: List
@dataclasses.dataclass
......@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
def _get_normalized_prompt_logprobs(
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
self, input_token_logprobs, logits_metadata: LogitsMetadata
):
logprobs_cumsum = torch.cumsum(
prefill_token_logprobs, dim=0, dtype=torch.float32
)
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = logits_metadata.extend_start_loc.clone()
end = start + logits_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_token_logprobs[start]
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
......@@ -83,34 +79,34 @@ class LogitsProcessor(nn.Module):
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = []
output_top_logprobs = []
for i in range(all_logprobs.shape[0]):
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs
output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, output_top_logprobs
else:
prefill_top_logprobs, decode_top_logprobs = [], []
input_top_logprobs, output_top_logprobs = [], []
pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
input_top_logprobs.append([])
output_top_logprobs.append([])
continue
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
input_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_len
return prefill_top_logprobs, decode_top_logprobs
return input_top_logprobs, output_top_logprobs
def forward(
self,
......@@ -150,9 +146,9 @@ class LogitsProcessor(nn.Module):
next_token_logits=last_logits,
next_token_logprobs=None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
else:
# When logprob is requested, compute the logits for all tokens.
......@@ -164,19 +160,19 @@ class LogitsProcessor(nn.Module):
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
decode_top_logprobs = self.get_top_logprobs(
output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
else:
decode_top_logprobs = None
output_top_logprobs = None
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=decode_top_logprobs,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=output_top_logprobs,
)
else:
all_logits = torch.matmul(hidden_states, weight.T)
......@@ -193,32 +189,32 @@ class LogitsProcessor(nn.Module):
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata
)
else:
prefill_top_logprobs = decode_top_logprobs = None
input_top_logprobs = output_top_logprobs = None
last_logprobs = all_logprobs[last_index]
# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
prefill_token_logprobs = all_logprobs[
input_token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, logits_metadata
input_token_logprobs, logits_metadata
)
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs,
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_top_logprobs=decode_top_logprobs,
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_top_logprobs=output_top_logprobs,
)
......
......@@ -226,9 +226,9 @@ class CudaGraphRunner:
next_token_logits=output.next_token_logits[:raw_bs],
next_token_logprobs=None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
# Extract logprobs
......@@ -242,7 +242,7 @@ class CudaGraphRunner:
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums,
)
output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
output.next_token_logprobs, logits_metadata
)[1]
......
......@@ -124,10 +124,10 @@ class Req:
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None
self.prefill_token_logprobs = None
self.prefill_top_logprobs = None
self.decode_token_logprobs = []
self.decode_top_logprobs = []
self.input_token_logprobs = None
self.input_top_logprobs = None
self.output_token_logprobs = []
self.output_top_logprobs = []
# The tokens are prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs
self.last_update_decode_tokens = 0
......@@ -244,8 +244,8 @@ class Req:
k = k + 1
else:
break
self.decode_token_logprobs = self.decode_token_logprobs[:k]
self.decode_top_logprobs = self.decode_top_logprobs[:k]
self.output_token_logprobs = self.output_token_logprobs[:k]
self.output_top_logprobs = self.output_top_logprobs[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k
......
......@@ -455,7 +455,7 @@ class ModelTpServer:
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
......@@ -481,24 +481,24 @@ class ModelTpServer:
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
if req.prefill_token_logprobs is None:
if req.input_token_logprobs is None:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list(
req.input_token_logprobs = list(
zip(
output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
req.input_ids[-req.extend_input_len + 1 :],
)
)
if req.logprob_start_len == 0:
req.prefill_token_logprobs = [
req.input_token_logprobs = [
(None, req.input_ids[0])
] + req.prefill_token_logprobs
] + req.input_token_logprobs
if req.last_update_decode_tokens != 0:
req.decode_token_logprobs.extend(
req.output_token_logprobs.extend(
list(
zip(
output.prefill_token_logprobs[
output.input_token_logprobs[
pt
+ req.extend_input_len
- req.last_update_decode_tokens : pt
......@@ -510,21 +510,21 @@ class ModelTpServer:
)
)
req.decode_token_logprobs.append(
req.output_token_logprobs.append(
(output.next_token_logprobs[i], next_token_ids[i])
)
if req.top_logprobs_num > 0:
if req.prefill_top_logprobs is None:
req.prefill_top_logprobs = output.prefill_top_logprobs[i]
if req.input_top_logprobs is None:
req.input_top_logprobs = output.input_top_logprobs[i]
if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
req.input_top_logprobs = [None] + req.input_top_logprobs
if req.last_update_decode_tokens != 0:
req.decode_top_logprobs.extend(
output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
req.output_top_logprobs.extend(
output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
)
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
req.output_top_logprobs.append(output.output_top_logprobs[i])
def cache_filled_batch(self, batch: Batch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
......@@ -589,11 +589,11 @@ class ModelTpServer:
req.check_finished()
if req.return_logprob:
req.decode_token_logprobs.append(
req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
req.output_top_logprobs.append(output.output_top_logprobs[i])
self.handle_finished_requests(batch)
......@@ -645,16 +645,16 @@ class ModelTpServer:
}
if req.return_logprob:
(
meta_info["prefill_token_logprobs"],
meta_info["decode_token_logprobs"],
meta_info["prefill_top_logprobs"],
meta_info["decode_top_logprobs"],
meta_info["input_token_logprobs"],
meta_info["output_token_logprobs"],
meta_info["input_top_logprobs"],
meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.prefill_token_logprobs,
req.decode_token_logprobs,
req.prefill_top_logprobs,
req.decode_top_logprobs,
req.input_token_logprobs,
req.output_token_logprobs,
req.input_top_logprobs,
req.output_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
......
......@@ -20,7 +20,7 @@ class GenerateReqInput:
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The sampling_params.
# The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
......@@ -30,7 +30,7 @@ class GenerateReqInput:
logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return.
top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in logprobs.
# Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False
# Whether to stream output.
stream: bool = False
......
......@@ -448,23 +448,23 @@ class TokenizerManager:
return_text_in_logprobs: bool,
):
if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = (
ret["meta_info"]["input_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"],
ret["meta_info"]["input_top_logprobs"],
return_text_in_logprobs,
)
)
ret["meta_info"]["decode_top_logprobs"] = (
ret["meta_info"]["output_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
)
)
return ret
......
......@@ -54,9 +54,9 @@ class LlamaForClassification(nn.Module):
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
prefill_token_logprobs=torch.ones_like(input_ids),
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=torch.ones_like(input_ids),
input_top_logprobs=None,
output_top_logprobs=None,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -140,29 +140,29 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if request.logprobs:
# The first chunk and echo is enabled.
if not stream_buffer and request.echo:
prefill_token_logprobs = content["meta_info"][
"prefill_token_logprobs"
input_token_logprobs = content["meta_info"][
"input_token_logprobs"
]
prefill_top_logprobs = content["meta_info"][
"prefill_top_logprobs"
input_top_logprobs = content["meta_info"][
"input_top_logprobs"
]
else:
prefill_token_logprobs = None
prefill_top_logprobs = None
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_token_logprobs=content["meta_info"][
"decode_token_logprobs"
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:],
decode_top_logprobs=content["meta_info"][
"decode_top_logprobs"
output_top_logprobs=content["meta_info"][
"output_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["decode_token_logprobs"]
content["meta_info"]["output_token_logprobs"]
)
else:
logprobs = None
......@@ -218,17 +218,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if request.logprobs:
if request.echo:
prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
else:
prefill_token_logprobs = None
prefill_top_logprobs = None
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)
else:
logprobs = None
......@@ -401,10 +401,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
def to_openai_style_logprobs(
prefill_token_logprobs=None,
decode_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=None,
output_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
):
ret_logprobs = LogProbs()
......@@ -425,13 +425,13 @@ def to_openai_style_logprobs(
else:
ret_logprobs.top_logprobs.append(None)
if prefill_token_logprobs is not None:
append_token_logprobs(prefill_token_logprobs)
if decode_token_logprobs is not None:
append_token_logprobs(decode_token_logprobs)
if prefill_top_logprobs is not None:
append_top_logprobs(prefill_top_logprobs)
if decode_top_logprobs is not None:
append_top_logprobs(decode_top_logprobs)
if input_token_logprobs is not None:
append_token_logprobs(input_token_logprobs)
if output_token_logprobs is not None:
append_token_logprobs(output_token_logprobs)
if input_top_logprobs is not None:
append_top_logprobs(input_top_logprobs)
if output_top_logprobs is not None:
append_top_logprobs(output_top_logprobs)
return ret_logprobs
......@@ -13,14 +13,15 @@ import json
import requests
def test_decode(url, return_logprob, top_logprobs_num, return_text):
def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 32,
"n": n,
},
"stream": False,
"return_logprob": return_logprob,
......@@ -41,8 +42,14 @@ if __name__ == "__main__":
url = f"{args.host}:{args.port}"
test_decode(url, False, 0, False)
test_decode(url, True, 0, False)
test_decode(url, True, 0, True)
test_decode(url, True, 3, False)
test_decode(url, True, 3, True)
test_decode(url)
test_decode(url, n=3)
for top_logprobs_num in [0, 3]:
for return_text in [True, False]:
test_decode(
url,
return_logprob=True,
top_logprobs_num=top_logprobs_num,
return_text=return_text,
)
......@@ -40,14 +40,14 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
data = json.loads(chunk[5:].strip("\n"))
if return_logprob:
assert data["meta_info"]["prefill_token_logprobs"] is not None
assert data["meta_info"]["decode_token_logprobs"] is not None
assert data["meta_info"]["input_token_logprobs"] is not None
assert data["meta_info"]["output_token_logprobs"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
for logprob, token_id, token_text in data["meta_info"][
"decode_token_logprobs"
"output_token_logprobs"
][prev:]:
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
prev = len(data["meta_info"]["decode_token_logprobs"])
prev = len(data["meta_info"]["output_token_logprobs"])
else:
output = data["text"].strip()
print(output[prev:], end="", flush=True)
......