Unverified Commit 839c93bd authored by narutolhy, committed by GitHub

feat: add original logprobs to response (#8375)


Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: luhongyu.4869 <luhongyu.4869@bytedance.com>
parent f1e9bbaf
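
In short, this change lets callers request logprobs computed from the pre-temperature logits by setting the RETURN_ORIGINAL_LOGPROB environment variable. A minimal usage sketch, modeled on the test added in this PR (the model path is illustrative, and the flag should be set before the engine is created because it is read when the sampler module is imported):

import os
os.environ["RETURN_ORIGINAL_LOGPROB"] = "true"  # set before starting the engine

import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen2-1.5B-Instruct")  # illustrative model
out = engine.generate(
    "The capital of France is",
    sampling_params={"temperature": 0.5, "max_new_tokens": 1},
    return_logprob=True,
    top_logprobs_num=5,
)
if isinstance(out, list):
    out = out[0]
# With the flag on, these logprobs correspond to temperature = 1.0 (original logits).
print(out["meta_info"]["output_token_logprobs"])
engine.shutdown()
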
@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
    hidden_states: Optional[torch.Tensor] = None

    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
    # The logprobs of the next tokens. shape: [#seq]
    # If RETURN_ORIGINAL_LOGPROB = True, these are the logprobs before applying
    # temperature; if False, the logprobs after applying temperature.
    next_token_logprobs: Optional[torch.Tensor] = None
    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
    next_token_top_logprobs_val: Optional[List] = None
...
@@ -27,6 +27,7 @@ if is_cuda():

logger = logging.getLogger(__name__)

SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
            batch_next_token_ids = torch.argmax(logits, -1)
            if return_logprob:
                logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
        else:
            # Compute probs from the original logits, before temperature scaling.
            # If temperatures are all 1.0, this matches the rescaled probs below.
            if return_logprob and RETURN_ORIGINAL_LOGPROB:
                logprobs = torch.softmax(logits, dim=-1)

            # Post process logits
            logits.div_(sampling_info.temperatures)
            logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):

        if return_logprob:
            # clamp to avoid -inf
            if RETURN_ORIGINAL_LOGPROB:
                logprobs = torch.log(logprobs).clamp(
                    min=torch.finfo(logprobs.dtype).min
                )
            else:
                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

        # Attach logprobs to logits_output (in-place modification)
        if return_logprob:
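
For clarity, a small standalone sketch of the two quantities the RETURN_ORIGINAL_LOGPROB flag switches between; the tensor values are illustrative and independent of the sampler code above:

# Standalone illustration: "original" logprobs come from the raw logits
# (temperature = 1.0), while the default logprobs come from the
# temperature-scaled logits that sampling actually uses.
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])  # illustrative logits for one sequence
temperature = 0.5

original_logprobs = torch.log_softmax(logits, dim=-1)              # RETURN_ORIGINAL_LOGPROB=true
scaled_logprobs = torch.log_softmax(logits / temperature, dim=-1)  # default behaviour

print(original_logprobs)
print(scaled_logprobs)
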
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)


def get_top_logprobs(
    logprobs: torch.Tensor,
    top_logprobs_nums: List[int],
):
    max_k = max(top_logprobs_nums)
    ret = logprobs.topk(max_k, dim=1)
    values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
    for i, k in enumerate(top_logprobs_nums):
        output_top_logprobs_val.append(values[i][:k])
        output_top_logprobs_idx.append(indices[i][:k])
    return (
        output_top_logprobs_val,
        output_top_logprobs_idx,
    )


def get_token_ids_logprobs(
    logprobs: torch.Tensor,
    token_ids_logprobs: List[List[int]],
):
    output_token_ids_logprobs_val = []
    output_token_ids_logprobs_idx = []
    for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
            output_token_ids_logprobs_val.append([])
            output_token_ids_logprobs_idx.append([])
    return (
        output_token_ids_logprobs_val,
        output_token_ids_logprobs_idx,
    )


def apply_custom_logit_processor(
...
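
To make the contract of these helpers concrete, a small self-contained example on a toy tensor (independent of SGLang):

# Toy example of the get_top_logprobs contract: for each sequence, return the
# top-k logprob values and the corresponding token ids, where k can differ per request.
import torch

logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)  # shape [#seq, vocab]
top_logprobs_nums = [3, 1]                               # per-request k

max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values, indices = ret.values.tolist(), ret.indices.tolist()

output_val = [values[i][:k] for i, k in enumerate(top_logprobs_nums)]
output_idx = [indices[i][:k] for i, k in enumerate(top_logprobs_nums)]
print(output_val)  # [[three values], [one value]]
print(output_idx)  # matching token ids
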
@@ -46,6 +46,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    get_bool_env_var,
    is_cuda,
    next_power_of_2,
)
@@ -54,6 +55,7 @@ if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
@contextmanager
@@ -788,15 +790,20 @@ class EAGLEWorker(TpModelWorker):
        token_ids_logprobs = batch.token_ids_logprobs
        accepted_indices = res.accepted_indices
        assert len(accepted_indices) == len(logits_output.next_token_logits)

        temperatures = batch.sampling_info.temperatures
        num_draft_tokens = batch.spec_info.draft_token_num
        # Acceptance indices are the indices in a "flattened" batch;
        # dividing them by num_draft_tokens yields the actual batch index.
        temperatures = temperatures[accepted_indices // num_draft_tokens]
        if RETURN_ORIGINAL_LOGPROB:
            logprobs = torch.nn.functional.log_softmax(
                logits_output.next_token_logits, dim=-1
            )
        else:
            logprobs = torch.nn.functional.log_softmax(
                logits_output.next_token_logits / temperatures, dim=-1
            )
        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
@@ -813,13 +820,19 @@ class EAGLEWorker(TpModelWorker):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(
                logprobs,
                top_logprobs_nums_repeat_interleaved,
            )

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(
                logprobs,
                token_ids_logprobs_repeat_interleaved,
            )

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
...
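
The accepted-index arithmetic above is easier to see with concrete numbers; a toy sketch with illustrative values, not taken from the PR:

# Toy illustration of mapping flattened accepted indices back to request indices.
# With num_draft_tokens = 4, flattened positions 0-3 belong to request 0 and 4-7 to request 1.
import torch

num_draft_tokens = 4
accepted_indices = torch.tensor([0, 1, 4, 5, 6])      # hypothetical accepted positions
batch_indices = accepted_indices // num_draft_tokens  # tensor([0, 0, 1, 1, 1])

# Each accepted token then picks up its own request's temperature:
temperatures = torch.tensor([[1.0], [0.5]])           # [batch_size, 1], illustrative
per_token_temperature = temperatures[batch_indices]   # shape [num_accepted, 1]
print(batch_indices, per_token_temperature.squeeze(-1))
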
@@ -87,6 +87,7 @@ suites = {
        TestFile("test_mla_fp8.py", 93),
        TestFile("test_no_chunked_prefill.py", 108),
        TestFile("test_no_overlap_scheduler.py", 234),
        TestFile("test_original_logprobs.py", 200),
        TestFile("test_penalty.py", 41),
        TestFile("test_page_size.py", 60),
        TestFile("test_pytorch_sampling_backend.py", 66),
...
"""Test original log probability alignment between SGLang and Hugging Face.
This test suite verifies the correctness of the `origin_logprobs` output (temperature=1)
and the `logprobs` output (temperature=0.5) in SGLang by comparing it against
raw logit-based probabilities computed directly from a reference Hugging Face model.
The test covers the following scenarios:
- Next-token prediction: Verifies that the log probability of the next token from
SGLang matches the Hugging Face model.
- Top-k logprobs: Ensures that the top-k original logprobs returned by SGLang are
consistent with Hugging Face outputs.
- Specified token IDs: Confirms that the original logprobs for specific token IDs
match the values computed from Hugging Face logits.
"""
import os
import random
import unittest
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# ------------------------- Test configuration ------------------------- #
MODEL_ID = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
PROMPTS = [
"Hello, my name is",
"The future of AI is",
"The president of the United States is",
"The capital of France is ",
]
TOP_LOGPROBS_NUM = 50
NUM_RANDOM_TOKEN_IDS = 10
RTOL = 0.20
ATOL = 0.00
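# Note: torch.allclose(hf, sgl, rtol=RTOL, atol=ATOL) accepts elementwise
# |hf - sgl| <= ATOL + RTOL * |sgl|, so RTOL = 0.20 allows roughly a 20%
# relative deviation between the HF reference and the SGLang logprobs.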
# ------------------------------------------------
torch.manual_seed(1234)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(1234)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class TestOriginalLogprob(unittest.TestCase):
    def setUp(self):
        # ----- HF side (float32 weights) -----
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="right")
        self.hf_model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.float32, device_map="auto"
        )

        # Shared sampling parameters
        self.sampling_params = {
            "temperature": 0.5,  # SGLang samples at 0.5; original logprobs correspond to 1.0
            "top_p": 1.0,
            "top_k": 10,
            "max_new_tokens": 1,
        }
    # ---------------------------------------------------------------------
    # Helper: compare one SGLang block (token_logprobs / top_logprobs / ids_logprobs)
    # against a reference HF log-prob vector.
    # ---------------------------------------------------------------------
    def assert_logprobs_block_equal(
        self,
        hf_log_probs: torch.Tensor,  # [V]
        token_log_probs: list,
        top_log_probs: list,
        ids_log_probs: list,
        random_token_ids: list,
        tag: str = "",
    ):
        vals, idxs, _ = zip(*token_log_probs)
        sgl_vals = torch.tensor(vals, device=self.hf_model.device, dtype=torch.float32)
        sgl_idxs = torch.tensor(idxs, device=self.hf_model.device, dtype=torch.long)

        hf_vals = hf_log_probs[sgl_idxs]
        self.assertTrue(
            torch.allclose(hf_vals, sgl_vals, rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] token-level mismatch at indices {sgl_idxs.tolist()}",
        )

        hf_topk, _ = torch.topk(hf_log_probs, k=TOP_LOGPROBS_NUM, dim=-1)
        sgl_topk = torch.tensor(
            [float(t[0]) for t in top_log_probs[0] if t and t[0] is not None][
                :TOP_LOGPROBS_NUM
            ],
            dtype=torch.float32,
            device=self.hf_model.device,
        )
        k = min(hf_topk.numel(), sgl_topk.numel())
        self.assertTrue(
            torch.allclose(hf_topk[:k], sgl_topk[:k], rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] top-k mismatch",
        )

        indices = torch.tensor(
            random_token_ids, dtype=torch.long, device=hf_log_probs.device
        )
        hf_token_ids = hf_log_probs[indices]
        sgl_token_ids = torch.tensor(
            [v for v, _, _ in ids_log_probs[0]],
            device=self.hf_model.device,
            dtype=torch.float32,
        )
        self.assertTrue(
            torch.allclose(hf_token_ids, sgl_token_ids, rtol=RTOL, atol=ATOL),
            msg=f"[{tag}] token-IDs mismatch",
        )

        # Optional: print max abs diff for quick diagnostics
        max_diff = torch.max(torch.abs(hf_vals - sgl_vals)).item()
        print(f"[{tag}] max|diff| token-level = {max_diff:.4f}")
    def test_logprob_match(self):
        vocab_size = self.tokenizer.vocab_size

        for env_val in ["True", "False"]:
            with self.subTest(return_original_logprob=env_val):
                os.environ["RETURN_ORIGINAL_LOGPROB"] = env_val

                # ----- SGLang side -----
                sgl_engine = sgl.Engine(
                    model_path=MODEL_ID,
                    skip_tokenizer_init=True,
                    trust_remote_code=True,
                    mem_fraction_static=0.60,
                )

                for prompt in PROMPTS:
                    random_token_ids = sorted(
                        random.sample(range(vocab_size), NUM_RANDOM_TOKEN_IDS)
                    )
                    enc = self.tokenizer(prompt, return_tensors="pt")
                    input_ids = enc["input_ids"].to(self.hf_model.device)
                    attn_mask = enc["attention_mask"].to(self.hf_model.device)

                    with torch.inference_mode():
                        hf_out = self.hf_model(
                            input_ids=input_ids,
                            attention_mask=attn_mask,
                            return_dict=True,
                        )
                    logits = hf_out.logits[:, -1, :]  # [1, V]
                    hf_log_probs = F.log_softmax(
                        logits.float() / self.sampling_params["temperature"], dim=-1
                    )[0]
                    hf_original_log_probs = F.log_softmax(logits.float(), dim=-1)[0]

                    outputs = sgl_engine.generate(
                        input_ids=input_ids[0].tolist(),
                        sampling_params=self.sampling_params,
                        return_logprob=True,
                        top_logprobs_num=TOP_LOGPROBS_NUM,
                        token_ids_logprob=random_token_ids,
                    )
                    if isinstance(outputs, list):
                        outputs = outputs[0]
                    meta = outputs["meta_info"]

                    if env_val.lower() == "true":
                        # With the flag on, compare against the original (temperature = 1.0) logprobs
                        self.assert_logprobs_block_equal(
                            hf_log_probs=hf_original_log_probs,
                            token_log_probs=meta["output_token_logprobs"],
                            top_log_probs=meta["output_top_logprobs"],
                            ids_log_probs=meta["output_token_ids_logprobs"],
                            random_token_ids=random_token_ids,
                            tag=f"Original logprobs SGLang vs HF: {prompt} ({env_val})",
                        )
                    else:
                        # Otherwise compare against the temperature-scaled logprobs
                        self.assert_logprobs_block_equal(
                            hf_log_probs=hf_log_probs,
                            token_log_probs=meta["output_token_logprobs"],
                            top_log_probs=meta["output_top_logprobs"],
                            ids_log_probs=meta["output_token_ids_logprobs"],
                            random_token_ids=random_token_ids,
                            tag=f"logprobs SGLang vs HF: {prompt} ({env_val})",
                        )

                sgl_engine.shutdown()


if __name__ == "__main__":
    unittest.main()