"docs/source/vscode:/vscode.git/clone" did not exist on "a4cacf13c2faa3fe12d6ad6d8a8b6cd4b067edbf"
Commit f65c13b5 authored by Lianmin Zheng's avatar Lianmin Zheng
Browse files

Remove normalized_prompt_logprobs from the engine to make code easier to maintain (#2902)

parent b803b395
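With this change, the engine no longer computes or returns the normalized prompt logprob; frontends derive it from the per-token prompt logprobs that /generate already returns (see the new compute_normalized_prompt_logprobs helper in the first hunk below). A minimal client-side sketch of the same derivation, assuming a server at http://127.0.0.1:30000 and assuming each entry of meta_info["input_token_logprobs"] is a (logprob, token_id, ...) tuple whose first logprob is None for the first prompt token:

import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "This is a test prompt.",
        "sampling_params": {"max_new_tokens": 0},
        "return_logprob": True,
    },
)
meta_info = response.json()["meta_info"]

# Mean logprob over the prompt tokens that actually have a logprob
# (the first prompt token has no conditional logprob and is skipped).
values = [entry[0] for entry in meta_info["input_token_logprobs"] if entry[0]]
normalized_prompt_logprob = sum(values) / len(values)
print(normalized_prompt_logprob)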
...@@ -251,11 +251,12 @@ class RuntimeEndpoint(BaseBackend):
}
obj = self._generate_http_request(s, data)
normalized_prompt_logprobs = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
normalized_prompt_logprobs = [
compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
for r in obj
]
# Remove extra token if no token healing occurred
for i in range(len(input_token_logprobs)):
...@@ -319,3 +320,8 @@ class RuntimeEndpoint(BaseBackend):
def _assert_success(self, res):
if res.status_code != 200:
raise RuntimeError(res.json())
def compute_normalized_prompt_logprobs(input_logprobs):
values = [x[0] for x in input_logprobs if x[0]]
return sum(values) / len(values)
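For reference, a hypothetical call to this helper; the input format (a (logprob, token_id, token_text) tuple per prompt token, with None for the first token's logprob) is assumed from the truthiness filter above, and the token ids are made up:

# Hypothetical input: the first prompt token has no logprob and is filtered out.
example = [(None, 128000, None), (-2.31, 9906, None), (-0.74, 1917, None)]
print(compute_normalized_prompt_logprobs(example))  # (-2.31 + -0.74) / 2 = -1.525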
...@@ -50,8 +50,6 @@ class LogitsProcessorOutput:
next_token_top_logprobs_idx: Optional[List] = None
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The normalized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor = None
# The logprobs of input tokens. shape: [#token]
input_token_logprobs: torch.Tensor = None
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
...@@ -195,8 +193,6 @@ class LogitsProcessor(nn.Module):
else:
input_top_logprobs_val = input_top_logprobs_idx = None
# Compute the normalized logprobs for the requested tokens.
# Note that we pad a zero at the end for easy batching.
input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device="cuda"),
torch.cat(
...@@ -206,14 +202,9 @@ class LogitsProcessor(nn.Module):
]
),
]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
input_token_logprobs,
logits_metadata,
)
return LogitsProcessorOutput(
next_token_logits=last_logits,
normalized_prompt_logprobs=normalized_prompt_logprobs,
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
...@@ -237,8 +228,6 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather:
logits = tensor_model_parallel_all_gather(logits)
# Compute the normalized logprobs for the requested tokens.
# Note that we pad a zero at the end for easy batching.
logits = logits[:, : self.config.vocab_size].float()
if self.final_logit_softcapping:
...@@ -246,27 +235,6 @@ class LogitsProcessor(nn.Module):
return logits
@staticmethod
def _get_normalized_prompt_logprobs(
input_token_logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
):
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
)
sum_logp = (
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
return normalized_prompt_logprobs
@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
max_k = max(logits_metadata.top_logprobs_nums)
......
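The removed _get_normalized_prompt_logprobs produced the same per-sequence mean, but batched on the GPU: it took a cumulative sum over the flattened input_token_logprobs and used the pruned per-sequence lengths (each of which includes one padded trailing element, hence the pruned_lens - 1 divisor) to slice out each sequence's sum. A small standalone sketch with made-up lengths, checking that the cumsum formulation matches a plain per-sequence mean:

import torch

# Two sequences flattened together; each ends with one padded 0.0 logprob.
input_token_logprobs = torch.tensor([-1.0, -2.0, 0.0, -0.5, -1.5, -2.5, 0.0])
pruned_lens = torch.tensor([3, 4])

# The removed cumsum formulation (device arguments dropped for a CPU-only sketch).
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = torch.zeros_like(pruned_lens)
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
end = torch.clamp(start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
batched = sum_logp / (pruned_lens - 1).clamp(min=1)

# Straightforward per-sequence mean, dropping each trailing pad.
per_seq = torch.stack(
    [chunk[:-1].mean() for chunk in torch.split(input_token_logprobs, pruned_lens.tolist())]
)
assert torch.allclose(batched, per_seq)  # both give tensor([-1.5000, -1.5000])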
...@@ -191,7 +191,6 @@ class DetokenizerManager:
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
)
)
......
...@@ -340,7 +340,6 @@ class BatchTokenIDOut:
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
normalized_prompt_logprob: List[float]
@dataclass
...@@ -366,7 +365,6 @@ class BatchStrOut:
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
normalized_prompt_logprob: List[float]
@dataclass
......
...@@ -280,7 +280,6 @@ class Req:
self.top_logprobs_num = top_logprobs_num
# Logprobs (return value)
self.normalized_prompt_logprob = None
self.input_token_logprobs_val = None
self.input_token_logprobs_idx = None
self.input_top_logprobs_val = None
...@@ -344,9 +343,6 @@ class Req:
max_prefix_len = min(max_prefix_len, input_len - 1)
if self.return_logprob:
if self.normalized_prompt_logprob is None:
# Need at least two tokens to compute normalized logprob
max_prefix_len = min(max_prefix_len, input_len - 2)
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
max_prefix_len = max(max_prefix_len, 0)
......
...@@ -433,7 +433,6 @@ class PrefillAdder:
or input_tokens <= self.rem_chunk_tokens
or (
req.return_logprob
and req.normalized_prompt_logprob is None
and req.logprob_start_len != len(req.origin_input_ids) - 1
)
):
......
...@@ -1038,9 +1038,6 @@ class Scheduler:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
# Check finish conditions
logprob_pt = 0
...@@ -1188,9 +1185,6 @@ class Scheduler:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
if req.input_token_logprobs_val is None:
input_token_logprobs_val = output.input_token_logprobs[
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
...@@ -1288,15 +1282,12 @@ class Scheduler:
input_top_logprobs_idx = []
output_top_logprobs_val = []
output_top_logprobs_idx = []
normalized_prompt_logprob = []
else:
input_token_logprobs_val = input_token_logprobs_idx = (
output_token_logprobs_val
) = output_token_logprobs_idx = input_top_logprobs_val = (
input_top_logprobs_idx
) = output_top_logprobs_val = output_top_logprobs_idx = (
normalized_prompt_logprob
) = None
) = output_top_logprobs_val = output_top_logprobs_idx = None
for req in reqs:
if req is skip_req:
...@@ -1343,7 +1334,6 @@ class Scheduler:
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
normalized_prompt_logprob.append(req.normalized_prompt_logprob)
# Send to detokenizer
if rids:
...@@ -1370,7 +1360,6 @@ class Scheduler:
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
normalized_prompt_logprob,
)
)
else: # embedding or reward model
......
...@@ -796,9 +796,6 @@ class TokenizerManager:
recv_obj.output_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
recv_obj_index
]
if top_logprobs_num > 0:
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
......
...@@ -151,11 +151,6 @@ class TpModelWorkerClient:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.to(
"cpu", non_blocking=True
)
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_done.record()
...@@ -174,9 +169,6 @@ class TpModelWorkerClient:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
return logits_output, next_token_ids
......
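The TpModelWorkerClient hunks above also show the overlap pattern this field participated in: prefill outputs are copied to the CPU with non_blocking=True and a CUDA event is recorded so the consumer only waits when it actually reads the values. A generic sketch of that pattern (not SGLang code; it needs a CUDA device, and a device-to-host copy only truly overlaps when the destination is pinned memory):

import torch

copy_done = torch.cuda.Event()
gpu_tensor = torch.randn(4, 8, device="cuda")

# Enqueue the device-to-host copy on the current stream and mark its completion point.
cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
copy_done.record()

# ... do other GPU/CPU work here ...

# Wait for the copy only when the host actually needs the values.
copy_done.synchronize()
print(cpu_tensor.sum())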
...@@ -535,7 +535,7 @@ def test_hellaswag_select():
# Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
assert np.abs(accuracy_gen - accuracy) < 0.01
assert np.abs(accuracy_gen - accuracy) < 0.05
assert np.abs(latency_gen - latency) < 1
return accuracy, latency
......
"""
Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
python3 test_httpserver_classify.py
"""
import argparse
import numpy as np
import requests
def get_logits_deprecated(url: str, prompt: str):
response = requests.post(
url + "/generate",
json={
"text": prompt,
"sampling_params": {
"max_new_tokens": 0,
},
"return_logprob": True,
},
)
return response.json()["meta_info"]["normalized_prompt_logprob"]
def get_logits_batch_deprecated(url: str, prompts: list[str]):
response = requests.post(
url + "/generate",
json={
"text": prompts,
"sampling_params": {
"max_new_tokens": 0,
},
"return_logprob": True,
},
)
ret = response.json()
logits = np.array(
list(
ret[i]["meta_info"]["normalized_prompt_logprob"]
for i in range(len(prompts))
)
)
return logits
def get_logits(url: str, prompt: str):
response = requests.post(
url + "/classify",
json={"text": prompt},
)
return response.json()["embedding"]
def get_logits_batch(url: str, prompts: list[str]):
response = requests.post(
url + "/classify",
json={"text": prompts},
)
return np.array([x["embedding"] for x in response.json()])
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
# A single request
prompt = "This is a test prompt.<|eot_id|>"
logits = get_logits(url, prompt)
print(f"{logits=}")
# A batch of requests
prompts = [
"This is a test prompt.<|eot_id|>",
"This is another test prompt.<|eot_id|>",
"This is a long long long long test prompt.<|eot_id|>",
]
logits = get_logits_batch(url, prompts)
print(f"{logits=}")
...@@ -42,7 +42,6 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
if return_logprob:
assert data["meta_info"]["input_token_logprobs"] is not None
assert data["meta_info"]["output_token_logprobs"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
for logprob, token_id, token_text in data["meta_info"][
"output_token_logprobs"
][prev:]:
......