Unverified Commit 21ec66e5 authored by Lianmin Zheng, committed by GitHub

Minor follow-up fixes for the logprob refactor (#2670)

parent c5210dfa
@@ -35,21 +35,21 @@ from sglang.srt.model_executor.forward_batch_info import (
 @dataclasses.dataclass
 class LogitsProcessorOutput:
-    ## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor.
+    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None
-    ## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler.
+    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
     # The logprobs of the next tokens. shape: [#seq]
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
-    ## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only.
+    ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
...
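Note: for orientation, a minimal sketch of how the three parts above get populated in sequence (the shapes and the toy sampling step are illustrative assumptions, not the real LogitsProcessor/Sampler code):

```python
import torch

from sglang.srt.layers.logits_processor import LogitsProcessorOutput

# Part 1: the logits processor constructs the output with next-token logits.
logits = torch.randn(2, 32000)  # [#seq, vocab_size], toy values
out = LogitsProcessorOutput(next_token_logits=logits)

# Part 2: the sampler later attaches the logprobs of the tokens it sampled.
probs = torch.softmax(out.next_token_logits, dim=-1)
token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)  # [#seq]
out.next_token_logprobs = torch.log(
    probs[torch.arange(probs.size(0)), token_ids]
)  # [#seq]
```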
@@ -56,7 +56,9 @@ class Sampler(nn.Module):
         if global_server_args_dict["sampling_backend"] == "flashinfer":
             if return_logprob:
-                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems
+                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
+                # https://github.com/flashinfer-ai/flashinfer/issues/708
+                # so we use the torch implementation.
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                 )
...
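Note: the torch path above renormalizes the probabilities under top-p and then takes the log. A rough sketch of what such a renormalization does (an illustration only, not the actual `top_p_normalize_probs_torch` implementation):

```python
import torch

def top_p_renormalize_sketch(probs: torch.Tensor, top_ps: torch.Tensor) -> torch.Tensor:
    """Zero out tokens outside the nucleus and rescale each row to sum to 1.

    Assumes probs is [#seq, vocab_size] and top_ps is [#seq].
    """
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds top_p.
    outside_nucleus = (cumsum - sorted_probs) > top_ps.unsqueeze(-1)
    sorted_probs = sorted_probs.masked_fill(outside_nucleus, 0.0)
    # Scatter back to vocabulary order and renormalize.
    renorm = torch.empty_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    return renorm / renorm.sum(dim=-1, keepdim=True)
```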
@@ -36,7 +36,7 @@ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import Sampler, get_top_logprobs
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -191,10 +191,9 @@ class ModelRunner:
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-        # TODO(liangan1):Just use gloo to bypass the initilization fail
-        # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
+            # TODO(liangan1):Just use gloo to bypass the initilization fail
+            # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
...
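Note: this hunk only moves the TODO comments into the xpu branch; the backend selection itself is unchanged and is ultimately what gets passed to torch.distributed. A small sketch of that flow (the helper name and the init arguments are assumptions for illustration, not ModelRunner code):

```python
import torch.distributed as dist

def pick_dist_backend(device: str) -> str:
    # Same selection logic as the hunk above, extracted for illustration.
    if device == "cuda":
        return "nccl"
    elif device == "xpu":
        # gloo as a stopgap until an xccl backend is available for XPU.
        return "gloo"
    elif device == "hpu":
        return "hccl"
    return "gloo"

# Hypothetical usage: initialize the process group with the chosen backend.
dist.init_process_group(
    backend=pick_dist_backend("cuda"),
    init_method="tcp://127.0.0.1:29500",  # assumed rendezvous address
    rank=0,
    world_size=1,
)
```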
@@ -244,7 +244,7 @@ class SamplingBatchInfo:
         # repetition
         if self.scaling_penalties is not None:
-            logits = torch.where(
+            logits[:] = torch.where(
                 logits > 0,
                 logits / self.scaling_penalties,
                 logits * self.scaling_penalties,
@@ -253,5 +253,3 @@ class SamplingBatchInfo:
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
             self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
-        return logits
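Note: `logits[:] = torch.where(...)` works together with dropping `return logits`: a slice assignment writes through the caller's tensor, whereas plain rebinding only changes the local name. A toy example of the difference (names are illustrative):

```python
import torch

def scale_rebind(logits: torch.Tensor) -> None:
    # Rebinds the local name only; the caller's tensor is left untouched.
    logits = torch.where(logits > 0, logits / 2, logits * 2)

def scale_in_place(logits: torch.Tensor) -> None:
    # Writes into the existing storage; the caller sees the update,
    # so no return value is needed.
    logits[:] = torch.where(logits > 0, logits / 2, logits * 2)

x = torch.tensor([2.0, -2.0])
scale_rebind(x)
print(x)  # tensor([ 2., -2.])  -- unchanged
scale_in_place(x)
print(x)  # tensor([ 1., -4.])  -- modified in place
```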
@@ -227,7 +227,7 @@ class TestSRTEndpoint(unittest.TestCase):
                 "regex": "( Yes| No)",
             },
             "return_logprob": True,
-            "top_logprobs_num": 5,
+            "top_logprobs_num": 5,  # The grammar constraint allows all prefix tokens so we need to use a larger top_k.
             "return_text_in_logprobs": True,
         },
     )
...
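Note: a hedged sketch of what such a request could look like against a locally running server (the URL, prompt, and printed output are assumptions, not taken from the test file):

```python
import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",  # assumed local server address
    json={
        "text": "Is 2 + 2 equal to 4? Answer:",
        "sampling_params": {
            "max_new_tokens": 1,
            "regex": "( Yes| No)",
        },
        "return_logprob": True,
        # Several tokens are valid prefixes under the grammar (" Yes", " No",
        # " Y", ...), so request enough top logprobs to cover them.
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
    },
)
print(response.json())
```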