Unverified commit 9c6ba248, authored by Lianmin Zheng, committed by GitHub

Refactor logprob computation to return the real logprob used in sampling (#2664)

parent b02da24a
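
At a high level, the refactor moves next-token logprob computation out of LogitsProcessor and into Sampler, so the logprob reported for a generated token is read from the same temperature-scaled, top-p-renormalized distribution that the token was actually sampled from, instead of a plain log-softmax of the raw logits. A minimal standalone sketch of that idea (plain PyTorch, not the SGLang code itself; the function name and the scalar top_p are illustrative):

import torch


def sample_with_real_logprob(logits: torch.Tensor, temperature: float, top_p: float):
    # Temperature scaling, then softmax
    probs = torch.softmax(logits / temperature, dim=-1)

    # Zero out tokens outside the top-p nucleus, then renormalize
    # (same logic as top_p_normalize_probs_torch in the diff below).
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    sorted_probs[(cumsum - sorted_probs) > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)

    # Sample, then read the logprob of the sampled token from the *same* distribution
    next_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
    logprobs = torch.log(probs)
    next_logprobs = logprobs[torch.arange(logits.shape[0]), next_ids]
    return next_ids, next_logprobs


ids, lps = sample_with_real_logprob(torch.randn(2, 8), temperature=0.7, top_p=0.9)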
@@ -17,6 +17,8 @@ import dataclasses
 from typing import List, Optional, Union

 import torch
+import triton
+import triton.language as tl
 from torch import nn
 from vllm.distributed import (
     get_tensor_model_parallel_world_size,
@@ -33,51 +35,55 @@ from sglang.srt.model_executor.forward_batch_info import (
 @dataclasses.dataclass
 class LogitsProcessorOutput:
+    ## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor.
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
-    # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor = None
+    # Used by speculative decoding (EAGLE)
+    # The last hidden layers
+    hidden_states: Optional[torch.Tensor] = None
+
+    ## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler.
+    # The logprobs of the next tokens. shape: [#seq]
+    next_token_logprobs: Optional[torch.Tensor] = None
+    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
+    next_token_top_logprobs_val: Optional[List] = None
+    next_token_top_logprobs_idx: Optional[List] = None

+    ## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only.
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
-    # The logprobs of input tokens. shape: [#token, vocab_size]
+    # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
-
-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None

-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
-    output_top_logprobs_val: List = None
-    output_top_logprobs_idx: List = None
-
-    # Used by speculative decoding (EAGLE)
-    # The output of transformer layers
-    hidden_states: Optional[torch.Tensor] = None


 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    top_logprobs_nums: Optional[List[int]]
-
-    return_logprob: bool = False
-    return_top_logprob: bool = False
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+
+    extend_return_logprob: bool = False
+    extend_return_top_logprob: bool = False

     extend_seq_lens: Optional[torch.Tensor] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

-    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    top_logprobs_nums: Optional[List[int]] = None

     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        extend_logprob_pruned_lens_cpu = None
+        if forward_batch.spec_info:
+            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+        else:
+            capture_hidden_mode = CaptureHiddenMode.NULL

-        if forward_batch.return_logprob:
-            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-            if forward_batch.forward_mode.is_extend():
+        if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
+            extend_return_logprob = True
+            extend_return_top_logprob = any(
+                x > 0 for x in forward_batch.top_logprobs_nums
+            )
             extend_logprob_pruned_lens_cpu = [
                 extend_len - start_len
                 for extend_len, start_len in zip(
@@ -86,23 +92,20 @@ class LogitsMetadata:
                 )
             ]
         else:
-            return_top_logprob = False
-
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
-        else:
-            capture_hidden_mode = CaptureHiddenMode.NULL
+            extend_return_logprob = extend_return_top_logprob = (
+                extend_logprob_pruned_lens_cpu
+            ) = False

         return cls(
             forward_mode=forward_batch.forward_mode,
-            top_logprobs_nums=forward_batch.top_logprobs_nums,
-            return_logprob=forward_batch.return_logprob,
-            return_top_logprob=return_top_logprob,
+            capture_hidden_mode=capture_hidden_mode,
+            extend_return_logprob=extend_return_logprob,
+            extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
-            capture_hidden_mode=capture_hidden_mode,
+            top_logprobs_nums=forward_batch.top_logprobs_nums,
         )
@@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module):
     ):
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-        assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
         if (
@@ -142,18 +144,10 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]

+        # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
-        if self.do_tensor_parallel_all_gather:
-            last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size].float()
-
-        if self.final_logit_softcapping:
-            last_logits.div_(self.final_logit_softcapping)
-            torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.final_logit_softcapping)
-
-        # Return only last_logits if logprob is not requested
-        if not logits_metadata.return_logprob:
+        if not logits_metadata.extend_return_logprob:
+            # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 hidden_states=(
@@ -166,74 +160,42 @@ class LogitsProcessor(nn.Module):
                     )
                 ),
             )
-        else:
-            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                last_logits, logits_metadata
-            )
-
-            if logits_metadata.forward_mode.is_decode():
-                if logits_metadata.return_top_logprob:
-                    output_top_logprobs_val, output_top_logprobs_idx = (
-                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
-                    )
-                else:
-                    output_top_logprobs_val = output_top_logprobs_idx = None
-                return LogitsProcessorOutput(
-                    next_token_logits=last_logits,
-                    next_token_logprobs=last_logprobs,
-                    output_top_logprobs_val=output_top_logprobs_val,
-                    output_top_logprobs_idx=output_top_logprobs_idx,
-                )
-            else:
+        else:
             # Slice the requested tokens to compute logprob
-            pt, states, pruned_input_ids = 0, [], []
+            pt, pruned_states, pruned_input_ids = 0, [], []
             for start_len, extend_len in zip(
                 logits_metadata.extend_logprob_start_lens_cpu,
                 logits_metadata.extend_seq_lens_cpu,
             ):
-                states.append(hidden_states[pt + start_len : pt + extend_len])
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
                 pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
                 pt += extend_len

-            # Compute the logits and logprobs for all required tokens
-            states = torch.cat(states, dim=0)
-            all_logits = self._get_logits(states, lm_head)
-            if self.do_tensor_parallel_all_gather:
-                all_logits = tensor_model_parallel_all_gather(all_logits)
-            # The LM head's weights may be zero-padded for parallelism. Remove any
-            # extra logits that this padding may have produced.
-            all_logits = all_logits[:, : self.config.vocab_size].float()
-
-            if self.final_logit_softcapping:
-                all_logits.div_(self.final_logit_softcapping)
-                torch.tanh(all_logits, out=all_logits)
-                all_logits.mul_(self.final_logit_softcapping)
-
-            all_logprobs = all_logits
-            del all_logits, hidden_states
-
-            all_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                all_logprobs, logits_metadata
-            )
+            # Compute the logits of all required tokens
+            pruned_states = torch.cat(pruned_states)
+            del hidden_states
+            input_token_logits = self._get_logits(pruned_states, lm_head)
+            del pruned_states
+
+            # Normalize the logprob w/o temperature, top-p
+            input_logprobs = input_token_logits
+            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                input_logprobs, logits_metadata
+            )

             # Get the logprob of top-k tokens
-            if logits_metadata.return_top_logprob:
+            if logits_metadata.extend_return_top_logprob:
                 (
                     input_top_logprobs_val,
                     input_top_logprobs_idx,
-                    output_top_logprobs_val,
-                    output_top_logprobs_idx,
-                ) = self.get_top_logprobs(all_logprobs, logits_metadata)
+                ) = self.get_top_logprobs(input_logprobs, logits_metadata)
             else:
-                input_top_logprobs_val = input_top_logprobs_idx = (
-                    output_top_logprobs_val
-                ) = output_top_logprobs_idx = None
+                input_top_logprobs_val = input_top_logprobs_idx = None

             # Compute the normalized logprobs for the requested tokens.
             # Note that we pad a zero at the end for easy batching.
-            input_token_logprobs = all_logprobs[
-                torch.arange(all_logprobs.shape[0], device="cuda"),
+            input_token_logprobs = input_logprobs[
+                torch.arange(input_logprobs.shape[0], device="cuda"),
                 torch.cat(
                     [
                         torch.cat(pruned_input_ids)[1:],
@@ -248,13 +210,10 @@ class LogitsProcessor(nn.Module):
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
-                output_top_logprobs_val=output_top_logprobs_val,
-                output_top_logprobs_idx=output_top_logprobs_idx,
             )

     def _get_logits(
@@ -269,9 +228,19 @@ class LogitsProcessor(nn.Module):
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

-        # Optional scaling factor
         if self.logit_scale is not None:
-            logits.mul_(self.logit_scale)  # In-place multiply
+            logits.mul_(self.logit_scale)
+
+        if self.do_tensor_parallel_all_gather:
+            logits = tensor_model_parallel_all_gather(logits)
+
+        # Compute the normalized logprobs for the requested tokens.
+        # Note that we pad a zero at the end for easy batching.
+        logits = logits[:, : self.config.vocab_size].float()
+
+        if self.final_logit_softcapping:
+            fused_softcap(logits, self.final_logit_softcapping)
+
         return logits

     @staticmethod
@@ -302,16 +271,7 @@ class LogitsProcessor(nn.Module):
         values = ret.values.tolist()
         indices = ret.indices.tolist()

-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs_val = []
-            output_top_logprobs_idx = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs_val.append(values[i][:k])
-                output_top_logprobs_idx.append(indices[i][:k])
-            return None, None, output_top_logprobs_val, output_top_logprobs_idx
-        else:
         input_top_logprobs_val, input_top_logprobs_idx = [], []
-        output_top_logprobs_val, output_top_logprobs_idx = [], []

         pt = 0
         for k, pruned_len in zip(
@@ -321,8 +281,6 @@ class LogitsProcessor(nn.Module):
             if pruned_len <= 0:
                 input_top_logprobs_val.append([])
                 input_top_logprobs_idx.append([])
-                output_top_logprobs_val.append([])
-                output_top_logprobs_idx.append([])
                 continue

             input_top_logprobs_val.append(
@@ -331,61 +289,55 @@ class LogitsProcessor(nn.Module):
             input_top_logprobs_idx.append(
                 [indices[pt + j][:k] for j in range(pruned_len - 1)]
             )
-            output_top_logprobs_val.append(
-                list(
-                    values[pt + pruned_len - 1][:k],
-                )
-            )
-            output_top_logprobs_idx.append(
-                list(
-                    indices[pt + pruned_len - 1][:k],
-                )
-            )
             pt += pruned_len

-        return (
-            input_top_logprobs_val,
-            input_top_logprobs_idx,
-            output_top_logprobs_val,
-            output_top_logprobs_idx,
-        )
+        return input_top_logprobs_val, input_top_logprobs_idx

     @staticmethod
     def compute_temp_top_p_normalized_logprobs(
         last_logits: torch.Tensor, logits_metadata: LogitsMetadata
     ) -> torch.Tensor:
+        # TODO: Implement the temp and top-p normalization
         return torch.nn.functional.log_softmax(last_logits, dim=-1)


-def test():
-    all_logprobs = torch.tensor(
-        #       s              s              s
-        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
-        dtype=torch.float32,
-        device="cuda",
-    )
-    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
-    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-
-    token_logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
-
-    len_cumsum = torch.cumsum(seq_lens, dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
-
-    # assert logprobs == [2, _, 2, 4, _]
-    print("token logprobs", token_logprobs)
-    print("start", start)
-    print("end", end)
-    print("sum_logp", sum_logp)
-
-
-if __name__ == "__main__":
-    test()
+@triton.jit
+def fused_softcap_kernel(
+    full_logits_ptr,
+    softcapping_value,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load values
+    x = tl.load(full_logits_ptr + offsets, mask=mask)
+
+    # Perform operations in-place
+    x = x / softcapping_value
+
+    # Manual tanh implementation using exp
+    exp2x = tl.exp(2 * x)
+    x = (exp2x - 1) / (exp2x + 1)
+
+    x = x * softcapping_value
+
+    # Store result
+    tl.store(full_logits_ptr + offsets, x, mask=mask)
+
+
+def fused_softcap(full_logits, final_logit_softcapping):
+    n_elements = full_logits.numel()
+    BLOCK_SIZE = 1024
+    grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)
+
+    fused_softcap_kernel[grid](
+        full_logits_ptr=full_logits,
+        softcapping_value=final_logit_softcapping,
+        n_elements=n_elements,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return full_logits
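
The fused_softcap kernel added above replaces the previous three-step softcapping of the logits (div_, tanh, mul_) with a single in-place Triton pass. A quick sanity check against a PyTorch reference, assuming a CUDA device and that the function is importable from sglang.srt.layers.logits_processor as the diff suggests; the cap value here is arbitrary:

import torch

from sglang.srt.layers.logits_processor import fused_softcap

cap = 30.0
logits = torch.randn(8, 32000, device="cuda", dtype=torch.float32)

# Reference: softcap(x) = cap * tanh(x / cap)
ref = cap * torch.tanh(logits / cap)

# The kernel writes in place, so hand it a copy
out = fused_softcap(logits.clone(), cap)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)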
 import logging
-from typing import Union
+from typing import List

 import torch
 from torch import nn
@@ -28,13 +28,12 @@ class Sampler(nn.Module):
     def forward(
         self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        logits_output: LogitsProcessorOutput,
         sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
     ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-        logits = logits.contiguous()
+        logits = logits_output.next_token_logits

         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -47,6 +46,8 @@ class Sampler(nn.Module):
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(logits, -1)
+            if return_logprob:
+                logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
             # Post process logits
             logits.div_(sampling_info.temperatures)
@@ -54,6 +55,12 @@ class Sampler(nn.Module):
             del logits

             if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if return_logprob:
+                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
                 max_top_k_round, batch_size = 32, probs.shape[0]
                 uniform_samples = torch.rand(
                     (max_top_k_round, batch_size), device=probs.device
@@ -76,6 +83,7 @@ class Sampler(nn.Module):
                 if self.use_nan_detectioin and not torch.all(success):
                     logger.warning("Detected errors during sampling!")
                     batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
@@ -85,12 +93,31 @@ class Sampler(nn.Module):
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
                 )
+                if return_logprob:
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )

-        return batch_next_token_ids.to(torch.int32)
+        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
+
+        # Attach logprobs to logits_output (in-place modification)
+        if return_logprob:
+            if any(x > 0 for x in top_logprobs_nums):
+                (
+                    logits_output.next_token_top_logprobs_val,
+                    logits_output.next_token_top_logprobs_idx,
+                ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+            logits_output.next_token_logprobs = logprobs[
+                torch.arange(len(batch_next_token_ids), device=sampling_info.device),
+                batch_next_token_ids,
+            ]
+
+        return batch_next_token_ids


 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -120,20 +147,27 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids


-def top_p_normalize_probs(
+def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
 ):
-    if global_server_args_dict["sampling_backend"] == "flashinfer":
-        return top_p_renorm_prob(probs, top_ps)
-    elif global_server_args_dict["sampling_backend"] == "pytorch":
     # See also top_k_top_p_min_p_sampling_from_probs_torch
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
     probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
     return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
-    else:
-        raise ValueError(
-            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-        )
+
+
+def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+    max_k = max(top_logprobs_nums)
+    ret = logprobs.topk(max_k, dim=1)
+    values = ret.values.tolist()
+    indices = ret.indices.tolist()
+
+    output_top_logprobs_val = []
+    output_top_logprobs_idx = []
+    for i, k in enumerate(top_logprobs_nums):
+        output_top_logprobs_val.append(values[i][:k])
+        output_top_logprobs_idx.append(indices[i][:k])
+    return output_top_logprobs_val, output_top_logprobs_idx
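
get_top_logprobs above is a plain helper: given a [#seq, vocab_size] logprob tensor and a per-request k, it returns Python lists that the scheduler can append directly to each request. A small self-contained illustration of the expected shapes (tiny vocabulary, made-up values), using the same logic as the function in the diff:

import torch
import torch.nn.functional as F

# Pretend batch of 2 requests over a vocabulary of 6 tokens
logprobs = F.log_softmax(torch.randn(2, 6), dim=-1)
top_logprobs_nums = [3, 1]  # request 0 wants top-3, request 1 wants top-1

max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values, indices = ret.values.tolist(), ret.indices.tolist()
top_vals = [values[i][:k] for i, k in enumerate(top_logprobs_nums)]
top_idxs = [indices[i][:k] for i, k in enumerate(top_logprobs_nums)]

print(top_vals)  # e.g. [[-0.9, -1.2, -2.0], [-0.4]]
print(top_idxs)  # e.g. [[5, 0, 3], [2]]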
@@ -974,12 +974,10 @@ class Scheduler:
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
         else:
             # Move next_token_ids and logprobs to cpu
+            next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs[
-                        torch.arange(len(next_token_ids), device=self.device),
-                        next_token_ids,
-                    ].tolist()
+                    logits_output.next_token_logprobs.tolist()
                 )
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.tolist()
@@ -987,7 +985,6 @@ class Scheduler:
                 logits_output.normalized_prompt_logprobs = (
                     logits_output.normalized_prompt_logprobs.tolist()
                 )
-            next_token_ids = next_token_ids.tolist()

         # Check finish conditions
         logprob_pt = 0
@@ -1064,13 +1061,9 @@ class Scheduler:
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
-            # Move next_token_ids and logprobs to cpu
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=self.device),
-                    next_token_ids,
-                ].tolist()
             next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs.tolist()

         self.token_to_kv_pool.free_group_begin()
@@ -1095,10 +1088,10 @@ class Scheduler:
                     req.output_token_logprobs_idx.append(next_token_id)
                     if req.top_logprobs_num > 0:
                         req.output_top_logprobs_val.append(
-                            logits_output.output_top_logprobs_val[i]
+                            logits_output.next_token_top_logprobs_val[i]
                         )
                         req.output_top_logprobs_idx.append(
-                            logits_output.output_top_logprobs_idx[i]
+                            logits_output.next_token_top_logprobs_idx[i]
                         )

                 if req.grammar is not None:
@@ -1200,8 +1193,9 @@ class Scheduler:
             req.output_top_logprobs_idx.extend(
                 output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
             )
-        req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
-        req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
+
+        req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
+        req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

         return num_input_logprobs
...
@@ -144,10 +144,9 @@ class TpModelWorkerClient:
         # Copy results to the CPU
         if model_worker_batch.return_logprob:
-            logits_output.next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=self.device),
-                next_token_ids,
-            ].to("cpu", non_blocking=True)
+            logits_output.next_token_logprobs = (
+                logits_output.next_token_logprobs.to("cpu", non_blocking=True)
+            )
             if logits_output.input_token_logprobs is not None:
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
...
@@ -392,34 +392,7 @@ class CudaGraphRunner:
         self.graphs[bs].replay()
         next_token_logits = self.output_buffers[bs][:raw_bs]

-        # Extract logprobs
-        if forward_batch.return_logprob:
-            logits_metadata = LogitsMetadata(
-                forward_mode=ForwardMode.DECODE,
-                top_logprobs_nums=forward_batch.top_logprobs_nums,
-            )
-            next_token_logprobs = (
-                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
-                    next_token_logits, logits_metadata
-                )
-            )
         logits_output = LogitsProcessorOutput(
             next_token_logits=next_token_logits,
-            next_token_logprobs=next_token_logprobs,
         )
-            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-            if return_top_logprob:
-                (
-                    logits_output.output_top_logprobs_val,
-                    logits_output.output_top_logprobs_idx,
-                ) = LogitsProcessor.get_top_logprobs(
-                    next_token_logprobs, logits_metadata
-                )[
-                    2:4
-                ]
-        else:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=next_token_logits,
-            )

         return logits_output
@@ -36,7 +36,7 @@ from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.sampler import Sampler, get_top_logprobs
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -48,7 +48,6 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
@@ -192,7 +191,8 @@ class ModelRunner:
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-        # ToDO(liangan1):Just use gloo to bypass the initilization fail
+
+        # TODO(liangan1):Just use gloo to bypass the initilization fail
         # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
             backend = "gloo"
@@ -704,6 +704,7 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
+        # Apply logit bias
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
@@ -714,34 +715,16 @@ class ModelRunner:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
             sampling_info.update_penalties()
-        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
+        sampling_info.apply_logits_bias(logits_output.next_token_logits)

-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, sampling_info)
-        return next_token_ids
-
-    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Apply logit_bias
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits.add_(sampling_info.linear_penalties)
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        # Apply regex vocab_mask
-        if sampling_info.vocab_mask is not None:
-            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
-
-        return logits
+        # Sample the next tokens
+        next_token_ids = self.sampler(
+            logits_output,
+            sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+        )
+        return next_token_ids

     @property
     def model_is_mrope(self) -> bool:
...
@@ -232,3 +232,26 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+
+    def apply_logits_bias(self, logits: torch.Tensor):
+        # Apply logit_bias
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
+        # min-token, presence, frequency
+        if self.linear_penalties is not None:
+            logits.add_(self.linear_penalties)
+
+        # repetition
+        if self.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / self.scaling_penalties,
+                logits * self.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if self.vocab_mask is not None:
+            self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
+
+        return logits
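
The repetition branch in apply_logits_bias divides positive logits by the penalty and multiplies negative ones, so a repeated token is pushed toward lower probability regardless of sign. A tiny worked example of just that branch (standalone, values made up):

import torch

logits = torch.tensor([[2.0, -2.0, 0.5]])
scaling_penalties = torch.tensor([[1.2, 1.2, 1.2]])  # penalty > 1 discourages repeats

penalized = torch.where(
    logits > 0,
    logits / scaling_penalties,  # 2.0 -> 1.67, 0.5 -> 0.42
    logits * scaling_penalties,  # -2.0 -> -2.4
)
print(penalized)  # tensor([[ 1.6667, -2.4000,  0.4167]])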
@@ -6,7 +6,7 @@ import requests
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -17,7 +17,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,
...
@@ -213,6 +213,41 @@ class TestSRTEndpoint(unittest.TestCase):
         max_diff = np.max(diff)
         self.assertLess(max_diff, 0.25)

+    def test_logprob_grammar(self):
+        prompts = "Question: Is Paris the Capital of France? Answer:"
+        allowed_tokens = [" Yes", " No"]
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": prompts,
+                "sampling_params": {
+                    "temperature": 1.0,
+                    "max_new_tokens": 1,
+                    "regex": "( Yes| No)",
+                },
+                "return_logprob": True,
+                "top_logprobs_num": 5,
+                "return_text_in_logprobs": True,
+            },
+        )
+        response_json = response.json()
+        output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0]
+        print(f"{output_top_logprobs=}")
+
+        # Parse results
+        # This is becaues the grammar constraint allows all prefix tokens
+        logprobs = [None] * 2
+        for i in range(len(output_top_logprobs)):
+            try:
+                idx = allowed_tokens.index(output_top_logprobs[i][2])
+            except ValueError:
+                # Not found
+                continue
+            logprobs[idx] = output_top_logprobs[i][0]
+
+        self.assertTrue(all(x is not None for x in logprobs))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
...