Commit f65c13b5 authored by Lianmin Zheng's avatar Lianmin Zheng
Browse files

Remove normalized_prompt_logprobs from the engine to make code easier to maintain (#2902)

parent b803b395
......@@ -251,11 +251,12 @@ class RuntimeEndpoint(BaseBackend):
}
obj = self._generate_http_request(s, data)
normalized_prompt_logprobs = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
normalized_prompt_logprobs = [
compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
for r in obj
]
# Remove extra token if no token healing occurred
for i in range(len(input_token_logprobs)):
......@@ -319,3 +320,8 @@ class RuntimeEndpoint(BaseBackend):
def _assert_success(self, res):
if res.status_code != 200:
raise RuntimeError(res.json())
def compute_normalized_prompt_logprobs(input_logprobs):
values = [x[0] for x in input_logprobs if x[0]]
return sum(values) / len(values)
......@@ -50,8 +50,6 @@ class LogitsProcessorOutput:
next_token_top_logprobs_idx: Optional[List] = None
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The normlaized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor = None
# The logprobs of input tokens. shape: [#token]
input_token_logprobs: torch.Tensor = None
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
......@@ -195,8 +193,6 @@ class LogitsProcessor(nn.Module):
else:
input_top_logprobs_val = input_top_logprobs_idx = None
# Compute the normalized logprobs for the requested tokens.
# Note that we pad a zero at the end for easy batching.
input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device="cuda"),
torch.cat(
......@@ -206,14 +202,9 @@ class LogitsProcessor(nn.Module):
]
),
]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
input_token_logprobs,
logits_metadata,
)
return LogitsProcessorOutput(
next_token_logits=last_logits,
normalized_prompt_logprobs=normalized_prompt_logprobs,
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
......@@ -237,8 +228,6 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather:
logits = tensor_model_parallel_all_gather(logits)
# Compute the normalized logprobs for the requested tokens.
# Note that we pad a zero at the end for easy batching.
logits = logits[:, : self.config.vocab_size].float()
if self.final_logit_softcapping:
......@@ -246,27 +235,6 @@ class LogitsProcessor(nn.Module):
return logits
@staticmethod
def _get_normalized_prompt_logprobs(
input_token_logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
):
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
)
sum_logp = (
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
return normalized_prompt_logprobs
@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
......
......@@ -191,7 +191,6 @@ class DetokenizerManager:
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
)
)
......
......@@ -340,7 +340,6 @@ class BatchTokenIDOut:
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
normalized_prompt_logprob: List[float]
@dataclass
......@@ -366,7 +365,6 @@ class BatchStrOut:
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
normalized_prompt_logprob: List[float]
@dataclass
......
......@@ -280,7 +280,6 @@ class Req:
self.top_logprobs_num = top_logprobs_num
# Logprobs (return value)
self.normalized_prompt_logprob = None
self.input_token_logprobs_val = None
self.input_token_logprobs_idx = None
self.input_top_logprobs_val = None
......@@ -344,9 +343,6 @@ class Req:
max_prefix_len = min(max_prefix_len, input_len - 1)
if self.return_logprob:
if self.normalized_prompt_logprob is None:
# Need at least two tokens to compute normalized logprob
max_prefix_len = min(max_prefix_len, input_len - 2)
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
max_prefix_len = max(max_prefix_len, 0)
......
......@@ -433,7 +433,6 @@ class PrefillAdder:
or input_tokens <= self.rem_chunk_tokens
or (
req.return_logprob
and req.normalized_prompt_logprob is None
and req.logprob_start_len != len(req.origin_input_ids) - 1
)
):
......
......@@ -1038,9 +1038,6 @@ class Scheduler:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
# Check finish conditions
logprob_pt = 0
......@@ -1188,9 +1185,6 @@ class Scheduler:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
if req.input_token_logprobs_val is None:
input_token_logprobs_val = output.input_token_logprobs[
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
......@@ -1288,15 +1282,12 @@ class Scheduler:
input_top_logprobs_idx = []
output_top_logprobs_val = []
output_top_logprobs_idx = []
normalized_prompt_logprob = []
else:
input_token_logprobs_val = input_token_logprobs_idx = (
output_token_logprobs_val
) = output_token_logprobs_idx = input_top_logprobs_val = (
input_top_logprobs_idx
) = output_top_logprobs_val = output_top_logprobs_idx = (
normalized_prompt_logprob
) = None
) = output_top_logprobs_val = output_top_logprobs_idx = None
for req in reqs:
if req is skip_req:
......@@ -1343,7 +1334,6 @@ class Scheduler:
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
normalized_prompt_logprob.append(req.normalized_prompt_logprob)
# Send to detokenizer
if rids:
......@@ -1370,7 +1360,6 @@ class Scheduler:
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
normalized_prompt_logprob,
)
)
else: # embedding or reward model
......
......@@ -796,9 +796,6 @@ class TokenizerManager:
recv_obj.output_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
recv_obj_index
]
if top_logprobs_num > 0:
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
......
......@@ -151,11 +151,6 @@ class TpModelWorkerClient:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.to(
"cpu", non_blocking=True
)
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_done.record()
......@@ -174,9 +169,6 @@ class TpModelWorkerClient:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
return logits_output, next_token_ids
......
......@@ -535,7 +535,7 @@ def test_hellaswag_select():
# Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
assert np.abs(accuracy_gen - accuracy) < 0.01
assert np.abs(accuracy_gen - accuracy) < 0.05
assert np.abs(latency_gen - latency) < 1
return accuracy, latency
......
"""
Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
python3 test_httpserver_classify.py
"""
import argparse
import numpy as np
import requests
def get_logits_deprecated(url: str, prompt: str):
response = requests.post(
url + "/generate",
json={
"text": prompt,
"sampling_params": {
"max_new_tokens": 0,
},
"return_logprob": True,
},
)
return response.json()["meta_info"]["normalized_prompt_logprob"]
def get_logits_batch_deprecated(url: str, prompts: list[str]):
response = requests.post(
url + "/generate",
json={
"text": prompts,
"sampling_params": {
"max_new_tokens": 0,
},
"return_logprob": True,
},
)
ret = response.json()
logits = np.array(
list(
ret[i]["meta_info"]["normalized_prompt_logprob"]
for i in range(len(prompts))
)
)
return logits
def get_logits(url: str, prompt: str):
response = requests.post(
url + "/classify",
json={"text": prompt},
)
return response.json()["embedding"]
def get_logits_batch(url: str, prompts: list[str]):
response = requests.post(
url + "/classify",
json={"text": prompts},
)
return np.array([x["embedding"] for x in response.json()])
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
# A single request
prompt = "This is a test prompt.<|eot_id|>"
logits = get_logits(url, prompt)
print(f"{logits=}")
# A batch of requests
prompts = [
"This is a test prompt.<|eot_id|>",
"This is another test prompt.<|eot_id|>",
"This is a long long long long test prompt.<|eot_id|>",
]
logits = get_logits_batch(url, prompts)
print(f"{logits=}")
......@@ -42,7 +42,6 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
if return_logprob:
assert data["meta_info"]["input_token_logprobs"] is not None
assert data["meta_info"]["output_token_logprobs"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
for logprob, token_id, token_text in data["meta_info"][
"output_token_logprobs"
][prev:]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment