Unverified commit d5fa019c, authored by Zhao Chen and committed by GitHub

feat: limit peak memory usage when computing logprobs (#6318)


Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
Co-authored-by: 赵晨阳 <zhaochen20@outlook.com>
parent fef3a6b6
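The change below chunks the logits processor's prompt-logprob path: instead of projecting every pruned hidden state to a full [num_tokens, vocab_size] logits tensor in one shot, it walks the tokens in fixed-size chunks and keeps only the per-token logprob results, so peak memory is bounded by the chunk size. A minimal standalone sketch of that pattern (toy tensors; a plain matmul stands in for the lm_head projection, and none of the SGLang metadata handling is modeled):

    import torch


    def chunked_token_logprobs(hidden, lm_weight, token_ids, chunk_size=2048):
        """Per-token logprobs without materializing the full [N, vocab] logits.

        hidden:    [N, d] pruned hidden states
        lm_weight: [vocab, d] output projection (toy stand-in for lm_head)
        token_ids: [N] token id whose logprob is wanted at each position
        """
        out = []
        for start in range(0, hidden.shape[0], chunk_size):
            chunk = hidden[start : start + chunk_size]    # [c, d]
            logits = chunk @ lm_weight.T                  # [c, vocab]; peak memory ~ c * vocab
            logprobs = torch.log_softmax(logits, dim=-1)
            ids = token_ids[start : start + chunk_size]
            out.append(logprobs[torch.arange(chunk.shape[0]), ids])
            del logits, logprobs                          # freed before the next chunk
        return torch.cat(out)


    # 10 tokens, hidden size 8, vocab 32, chunks of 4 rows
    h, w = torch.randn(10, 8), torch.randn(32, 8)
    ids = torch.randint(0, 32, (10,))
    print(chunked_token_logprobs(h, w, ids, chunk_size=4).shape)  # torch.Size([10])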
@@ -273,6 +273,10 @@ class Envs:
    # Sparse Embeddings
    SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)

    # Logits processor
    SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK = EnvBool(False)
    SGLANG_LOGITS_PROCESSER_CHUNK_SIZE = EnvInt(2048)

    # Tool-Call behavior
    SGLANG_TOOL_STRICT_LEVEL = EnvInt(ToolStrictLevel.OFF)
...
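The two new variables gate the feature and set the chunk granularity. A usage sketch (the string values mirror how the test file further down toggles them; 512 is an arbitrary example value, the default is 2048, and the variables must be set before the engine process starts):

    import os

    # Opt in to chunked logprob computation and use 512-row chunks.
    os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
    os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = "512"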
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import dataclasses import dataclasses
import logging import logging
from typing import List, Optional, Union from typing import List, Optional, Tuple, Union
import torch import torch
import triton import triton
...@@ -26,6 +26,7 @@ from sglang.srt.distributed import ( ...@@ -26,6 +26,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from sglang.srt.environ import envs
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
DpPaddingMode, DpPaddingMode,
attn_tp_all_gather, attn_tp_all_gather,
...@@ -53,6 +54,15 @@ logger = logging.getLogger(__name__) ...@@ -53,6 +54,15 @@ logger = logging.getLogger(__name__)
_is_npu = is_npu() _is_npu = is_npu()

@dataclasses.dataclass
class InputLogprobsResult:
    input_token_logprobs: torch.Tensor
    input_top_logprobs_val: Optional[List] = None
    input_top_logprobs_idx: Optional[List] = None
    input_token_ids_logprobs_val: Optional[List] = None
    input_token_ids_logprobs_idx: Optional[List] = None


@dataclasses.dataclass
class LogitsProcessorOutput:
    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -248,6 +258,11 @@ class LogitsProcessor(nn.Module):
        ):
            self.final_logit_softcapping = None

        # enable chunked logprobs processing
        self.enable_logprobs_chunk = envs.SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK.value
        # chunk size for logprobs processing
        self.logprobs_chunk_size = envs.SGLANG_LOGITS_PROCESSER_CHUNK_SIZE.value

    def compute_logprobs_for_multi_item_scoring(
        self,
        input_ids,
@@ -405,19 +420,41 @@ class LogitsProcessor(nn.Module):
            sample_indices = None
            input_logprob_indices = None
        else:
            # Prefill with input logprobs.
            # Find 4 different indices.
            # 1. pruned_states: hidden states that we want logprobs from.
            # 2. sample_indices: Indices that have sampled tokens.
            # 3. input_logprob_indices: Indices that have input logprob tokens.
            # 4. token_to_seq_idx: map each token to its sequence index
            #
            # Example
            # -------
            # Suppose a batch (flattened by sequence):
            #   [t00, t01, t02, t03, t10, t11, t12, t13, t14, t20, t21, t22, t23, t24, t25]
            # extend_seq_lens_cpu = [4, 5, 6]
            # extend_logprob_start_lens_cpu = [0, 5, 3]
            #
            # Then, the indices are:
            #   pruned_states         -> [t00, t01, t02, t03, t14, t23, t24, t25]
            #   sample_indices        -> [3, 4, 7]
            #   input_logprob_indices -> [0, 1, 2, 3, 5, 6, 7]
            #   token_to_seq_idx      -> [0, 0, 0, 0, 1, 2, 2, 2]
            #
            # If chunking is enabled and chunk_size = 3, the tokens are processed in chunks:
            #   [t00, t01, t02], [t03, t14, t23], [t24, t25]
            sample_index_pt = -1
            sample_indices = []
            input_logprob_indices_pt = 0
            input_logprob_indices = []
            pt, pruned_states = 0, []
            token_to_seq_idx = []

            for idx, (extend_logprob_start_len, extend_len) in enumerate(
                zip(
                    logits_metadata.extend_logprob_start_lens_cpu,
                    logits_metadata.extend_seq_lens_cpu,
                )
            ):
                # It can happen in chunked prefill. We still need to sample 1 token,
                # But we don't want to include it in input logprob.
@@ -430,6 +467,9 @@ class LogitsProcessor(nn.Module):
                # by a caller.
                assert extend_len > start_len
                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
                # Map each token to its sequence index, for chunked computation
                # of input logprobs
                token_to_seq_idx.extend([idx] * (extend_len - start_len))
                pt += extend_len
                sample_index_pt += extend_len - start_len
                sample_indices.append(sample_index_pt)
@@ -441,6 +481,8 @@ class LogitsProcessor(nn.Module):
                )
                input_logprob_indices_pt += extend_len - start_len

            # Set the last token of the last sequence
            token_to_seq_idx.append(len(logits_metadata.extend_seq_lens_cpu) - 1)
            pruned_states = torch.cat(pruned_states)
            sample_indices = torch.tensor(
                sample_indices, device=pruned_states.device, dtype=torch.int64
@@ -449,12 +491,6 @@ class LogitsProcessor(nn.Module):
                input_logprob_indices, device=pruned_states.device, dtype=torch.int64
            )

        hidden_states_to_store: Optional[torch.Tensor] = None
        if logits_metadata.capture_hidden_mode.need_capture():
            if logits_metadata.capture_hidden_mode.is_full():
@@ -482,67 +518,278 @@ class LogitsProcessor(nn.Module):
            else:
                assert False, "Should never reach"

        if not logits_metadata.extend_return_logprob:
            # Compute logits for both input and sampled tokens.
            logits = self._get_logits(pruned_states, lm_head, logits_metadata)
            sampled_logits = (
                logits[sample_indices] if sample_indices is not None else logits
            )
            # Decode mode or extend mode without return_logprob.
            return LogitsProcessorOutput(
                next_token_logits=sampled_logits,
                hidden_states=hidden_states_to_store,
            )
        else:
            # Start to process input logprobs
            # Normalize the logprob w/o temperature, top-p
            pruned_lens = torch.tensor(
                logits_metadata.extend_logprob_pruned_lens_cpu,
                device=pruned_states.device,
            )
            if logits_metadata.temp_scaled_logprobs:
                logits_metadata.temperature = torch.repeat_interleave(
                    logits_metadata.temperature.view(-1),
                    pruned_lens,
                ).view(-1, 1)
            if logits_metadata.top_p_normalized_logprobs:
                logits_metadata.top_p = torch.repeat_interleave(
                    logits_metadata.top_p,
                    pruned_lens,
                )

            # Determine whether to use chunked or non-chunked logits processing.
            # Skip chunking if:
            # 1. Chunking is disabled
            # 2. Total count is below chunk size threshold
            # 3. DP attention all-gather is enabled (can use "enable_dp_lm_head" to enable chunking)
            should_skip_chunking = (
                not self.enable_logprobs_chunk
                or pruned_states.shape[0] <= self.logprobs_chunk_size
                or self.do_tensor_parallel_all_gather_dp_attn
            )
            if should_skip_chunking:
                # Compute logits for both input and sampled tokens.
                logits = self._get_logits(pruned_states, lm_head, logits_metadata)
                sampled_logits = (
                    logits[sample_indices] if sample_indices is not None else logits
                )
                input_logprobs = logits[input_logprob_indices]
                del logits
                logprobs_result = self._process_input_logprobs(
                    input_logprobs, logits_metadata
                )
            else:
                (logprobs_result, sampled_logits) = self._process_input_logprobs_by_chunk(
                    pruned_states,
                    sample_indices,
                    input_logprob_indices,
                    token_to_seq_idx,
                    lm_head,
                    logits_metadata,
                )

            return LogitsProcessorOutput(
                next_token_logits=sampled_logits,
                hidden_states=hidden_states_to_store,
                input_token_logprobs=logprobs_result.input_token_logprobs,
                input_top_logprobs_val=logprobs_result.input_top_logprobs_val,
                input_top_logprobs_idx=logprobs_result.input_top_logprobs_idx,
                input_token_ids_logprobs_val=logprobs_result.input_token_ids_logprobs_val,
                input_token_ids_logprobs_idx=logprobs_result.input_token_ids_logprobs_idx,
            )

    def _process_input_logprobs(self, input_logprobs, logits_metadata):
        input_logprobs = self.compute_temp_top_p_normalized_logprobs(
            input_logprobs, logits_metadata
        )

        # Get the logprob of top-k tokens
        if logits_metadata.extend_return_top_logprob:
            (
                input_top_logprobs_val,
                input_top_logprobs_idx,
            ) = self.get_top_logprobs(input_logprobs, logits_metadata)
        else:
            input_top_logprobs_val = input_top_logprobs_idx = None

        # Get the logprob of given token id
        if logits_metadata.extend_token_ids_logprob:
            (
                input_token_ids_logprobs_val,
                input_token_ids_logprobs_idx,
            ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
        else:
            input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None

        input_token_logprobs = input_logprobs[
            torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
            logits_metadata.extend_input_logprob_token_ids_gpu,
        ]

        return InputLogprobsResult(
            input_token_logprobs=input_token_logprobs,
            input_top_logprobs_val=input_top_logprobs_val,
            input_top_logprobs_idx=input_top_logprobs_idx,
            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
        )

    def _process_input_logprobs_by_chunk(
        self,
        pruned_states: torch.Tensor,
        sample_indices: torch.Tensor,
        input_logprob_indices: torch.Tensor,
        token_to_seq_idx: list[int],
        lm_head: VocabParallelEmbedding,
        logits_metadata: LogitsMetadata,
    ) -> Tuple[InputLogprobsResult, torch.Tensor]:
        """
        Compute input logprobs (and gather the sampled logits) from the hidden states.

        To avoid using too much memory, we split pruned_states into chunks of rows,
        compute input_logprobs for each chunk separately, and then concatenate the
        results.

        Returns:
            InputLogprobsResult: logprobs result
            torch.Tensor: sampled logits
        """
        # The peak memory usage is proportional to the chunk size.
        chunk_size = self.logprobs_chunk_size
        total_size = pruned_states.shape[0]
        num_chunks = (total_size + chunk_size - 1) // chunk_size

        input_token_logprobs = []
        if logits_metadata.extend_return_top_logprob:
            input_top_logprobs_val = []
            input_top_logprobs_idx = []
        else:
            input_top_logprobs_val = None
            input_top_logprobs_idx = None
        if logits_metadata.extend_token_ids_logprob:
            input_token_ids_logprobs_val = []
            input_token_ids_logprobs_idx = []
        else:
            input_token_ids_logprobs_val = None
            input_token_ids_logprobs_idx = None

        # If a single sequence is split into multiple chunks, we need to keep track
        # of the pruned length of the sequences in the previous chunks.
        split_len_topk = 0
        split_len_token_ids = 0

        for i in range(num_chunks):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, total_size)

            # Get indices for this chunk
            chunk_mask = (input_logprob_indices >= start_idx) & (
                input_logprob_indices < end_idx
            )
            global_indices = input_logprob_indices[chunk_mask]
            chunk_indices = global_indices - start_idx

            # Get the positions in the original array where chunk_mask is True.
            # This is needed to correctly index into extend_input_logprob_token_ids_gpu
            mask_indices = torch.nonzero(chunk_mask, as_tuple=True)[0]

            # Get the logits for this chunk
            chunk_states = pruned_states[start_idx:end_idx]
            chunk_logits = self._get_logits(chunk_states, lm_head, logits_metadata)

            # Initialize sampled_logits on first chunk
            if i == 0:
                sampled_logits = torch.empty(
                    (sample_indices.shape[0], chunk_logits.shape[1]),
                    dtype=chunk_logits.dtype,
                    device=chunk_logits.device,
                )

            # Handle sampled logits for the chunk if needed.
            # This must be done before the continue statement to ensure all sampled_logits are filled
            chunk_sample_mask = (sample_indices >= start_idx) & (
                sample_indices < end_idx
            )
            if chunk_sample_mask.any():
                chunk_sample_indices = sample_indices[chunk_sample_mask] - start_idx
                sampled_logits[chunk_sample_mask] = chunk_logits[chunk_sample_indices]

            # If there are no input logprobs in this chunk, skip the rest
            if chunk_indices.numel() == 0:
                continue

            # Compute the logprobs of the chunk
            chunk_input_logprobs = chunk_logits[chunk_indices]
            chunk_temperature = (
                logits_metadata.temperature[global_indices]
                if logits_metadata.temperature is not None
                else None
            )
            chunk_top_p = (
                logits_metadata.top_p[global_indices]
                if logits_metadata.top_p is not None
                else None
            )
            chunk_input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                chunk_input_logprobs,
                logits_metadata,
                chunk_top_p,
                chunk_temperature,
            )

            # For each chunk, we need to get the slice of the token_to_seq_idx
            chunk_slice = slice(
                token_to_seq_idx[start_idx], token_to_seq_idx[end_idx] + 1
            )

            # Get the logprob of top-k tokens
            if logits_metadata.extend_return_top_logprob:
                top_k_nums = logits_metadata.top_logprobs_nums[chunk_slice]
                pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[
                    chunk_slice
                ]
                split_len_topk = self.get_top_logprobs_chunk(
                    chunk_input_logprobs,
                    logits_metadata,
                    top_k_nums,
                    pruned_lens,
                    input_top_logprobs_val,
                    input_top_logprobs_idx,
                    split_len_topk,
                )

            # Get the logprob of given token id
            if logits_metadata.extend_token_ids_logprob:
                token_ids_logprobs = logits_metadata.token_ids_logprobs[chunk_slice]
                pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[
                    chunk_slice
                ]
                split_len_token_ids = self.get_token_ids_logprobs_chunk(
                    chunk_input_logprobs,
                    logits_metadata,
                    token_ids_logprobs,
                    pruned_lens,
                    input_token_ids_logprobs_val,
                    input_token_ids_logprobs_idx,
                    split_len_token_ids,
                )

            # Get the logprob of the requested token ids
            chunk_input_token_logprobs = chunk_input_logprobs[
                torch.arange(
                    chunk_input_logprobs.shape[0], device=chunk_input_logprobs.device
                ),
                logits_metadata.extend_input_logprob_token_ids_gpu[mask_indices],
            ]
            input_token_logprobs.append(chunk_input_token_logprobs)

        # Concatenate the results
        input_token_logprobs = torch.cat(input_token_logprobs, dim=0)

        return (
            InputLogprobsResult(
                input_token_logprobs=input_token_logprobs,
                input_top_logprobs_val=input_top_logprobs_val,
                input_top_logprobs_idx=input_top_logprobs_idx,
                input_token_ids_logprobs_val=input_token_ids_logprobs_val,
                input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
            ),
            sampled_logits,
        )

    def _get_logits(
        self,
@@ -691,6 +938,80 @@ class LogitsProcessor(nn.Module):
        return input_top_logprobs_val, input_top_logprobs_idx

    @staticmethod
    def get_top_logprobs_chunk(
        logprobs: torch.Tensor,
        logits_metadata: LogitsMetadata,
        top_k_nums: List[int],
        pruned_lens: List[int],
        input_top_logprobs_val: List,
        input_top_logprobs_idx: List,
        split_pruned_len: int,
    ) -> int:
        """Get top-k logprobs for each sequence in the chunk.

        Args:
            logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
            logits_metadata: Metadata containing top-k and pruned length info
            top_k_nums: List of top-k numbers for each sequence
            pruned_lens: List of pruned lengths for each sequence
            input_top_logprobs_val: List to store top-k logprob values
            input_top_logprobs_idx: List to store top-k token indices
            split_pruned_len: Length of pruned tokens from previous chunk

        Returns:
            int: Number of remaining tokens to process in next chunk
        """
        # No sequences in the chunk
        if logprobs.shape[0] == 0:
            return 0

        max_k = max(logits_metadata.top_logprobs_nums)
        ret = logprobs.topk(max_k, dim=1)
        values = ret.values.tolist()
        indices = ret.indices.tolist()

        pt = 0
        next_split_pruned_len = 0
        for n, (k, pruned_len) in enumerate(zip(top_k_nums, pruned_lens)):
            if n == 0:
                # For the first sequence, adjust the pruned length
                pruned_len -= split_pruned_len
            else:
                # After the first sequence, no split in the middle
                split_pruned_len = 0

            if pruned_len <= 0:
                # if pruned length is less than or equal to 0,
                # there is no top-k logprobs to process
                input_top_logprobs_val.append([])
                input_top_logprobs_idx.append([])
                continue

            # Get the top-k logprobs
            val = []
            idx = []
            for j in range(pruned_len):
                # Handle remaining tokens in next chunk if any
                if pt + j >= len(values):
                    next_split_pruned_len = split_pruned_len + j
                    break
                # Append the top-k logprobs
                val.append(values[pt + j][:k])
                idx.append(indices[pt + j][:k])

            # Append or extend based on whether the sequence was split across chunks
            if len(val) > 0:
                if split_pruned_len > 0:
                    input_top_logprobs_val[-1].extend(val)
                    input_top_logprobs_idx[-1].extend(idx)
                else:
                    input_top_logprobs_val.append(val)
                    input_top_logprobs_idx.append(idx)
            pt += pruned_len

        return next_split_pruned_len

    @staticmethod
    def get_token_ids_logprobs(
        all_logprobs: torch.Tensor,
@@ -724,9 +1045,86 @@ class LogitsProcessor(nn.Module):
        return input_token_ids_logprobs_val, input_token_ids_logprobs_idx

    @staticmethod
    def get_token_ids_logprobs_chunk(
        logprobs: torch.Tensor,
        logits_metadata: LogitsMetadata,
        token_ids_logprobs: List[int],
        pruned_lens: List[int],
        input_token_ids_logprobs_val: List,
        input_token_ids_logprobs_idx: List,
        split_pruned_len: int = 0,
    ):
        """Get token_ids logprobs for each sequence in the chunk.

        Args:
            logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
            logits_metadata: Metadata containing token IDs and pruned length info
            token_ids_logprobs: List of token IDs for each sequence
            pruned_lens: List of pruned lengths for each sequence
            input_token_ids_logprobs_val: List to store token logprob values
            input_token_ids_logprobs_idx: List to store token indices
            split_pruned_len: Length of pruned tokens from previous chunk

        Returns:
            int: Number of remaining tokens to process in next chunk
        """
        # No sequences in the chunk
        if logprobs.shape[0] == 0:
            return 0

        pt = 0
        next_split_pruned_len = 0
        for n, (token_ids, pruned_len) in enumerate(
            zip(
                token_ids_logprobs,
                pruned_lens,
            )
        ):
            # Adjust pruned length for first sequence
            if n == 0:
                pruned_len -= split_pruned_len
            else:
                split_pruned_len = 0

            if pruned_len <= 0:
                # if pruned length is less than or equal to 0,
                # there is no token ids logprobs to process
                input_token_ids_logprobs_val.append([])
                input_token_ids_logprobs_idx.append([])
                continue

            # Get the token ids logprobs
            val = []
            idx = []
            for j in range(pruned_len):
                # Handle remaining tokens in next chunk if any
                if pt + j >= logprobs.shape[0]:
                    next_split_pruned_len = split_pruned_len + j
                    break
                if token_ids is not None:
                    val.append(logprobs[pt + j, token_ids].tolist())
                    idx.append(token_ids)

            # Append or extend based on whether the sequence was split across chunks
            if len(val) > 0:
                if split_pruned_len > 0:
                    input_token_ids_logprobs_val[-1].extend(val)
                    input_token_ids_logprobs_idx[-1].extend(idx)
                else:
                    input_token_ids_logprobs_val.append(val)
                    input_token_ids_logprobs_idx.append(idx)
            pt += pruned_len

        return next_split_pruned_len

    @staticmethod
    def compute_temp_top_p_normalized_logprobs(
        last_logits: torch.Tensor,
        logits_metadata: LogitsMetadata,
        top_p: Optional[torch.Tensor] = None,
        temperature: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        compute logprobs for the output token from the given logits.
@@ -734,21 +1132,23 @@ class LogitsProcessor(nn.Module):
        Returns:
            torch.Tensor: logprobs from logits
        """
        if top_p is None:
            top_p = logits_metadata.top_p
        if temperature is None:
            temperature = logits_metadata.temperature

        # Scale logits if temperature scaling is enabled
        if logits_metadata.temp_scaled_logprobs:
            last_logits = last_logits / temperature

        # Normalize logprobs if top_p normalization is enabled
        # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
        if logits_metadata.top_p_normalized_logprobs and (top_p != 1.0).any():
            from sglang.srt.layers.sampler import top_p_normalize_probs_torch

            probs = torch.softmax(last_logits, dim=-1)
            del last_logits
            probs = top_p_normalize_probs_torch(probs, top_p)
            return torch.log(probs)
        else:
            return torch.nn.functional.log_softmax(last_logits, dim=-1)
...
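Rough, back-of-the-envelope sizing of the effect (the numbers are hypothetical, not measured): with a ~152k-entry vocabulary and fp32 logprobs, materializing logits for 16,384 pruned prefill tokens at once needs roughly 9 GiB, while a single 2,048-row chunk needs about 1.2 GiB:

    vocab, dtype_bytes = 151_936, 4               # hypothetical vocab size, fp32
    tokens, chunk = 16_384, 2_048                 # pruned prefill tokens vs. default chunk size
    print(tokens * vocab * dtype_bytes / 2**30)   # ~9.3 GiB of logits at once
    print(chunk * vocab * dtype_bytes / 2**30)    # ~1.2 GiB per chunk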
@@ -85,6 +85,16 @@ MAX_LEN = 20000
DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
DEFAULT_META_JSON = "baseline_meta_preview.json"

# Default engine configuration
DEFAULT_ENGINE_CONFIG = {
    "model_path": DENSE_MODEL_NAME,
    "random_seed": 42,
    "skip_tokenizer_init": True,
    "mem_fraction_static": 0.8,
    "enable_deterministic_inference": True,
    "attention_backend": "flashinfer",
}

def generate_baseline(
    baseline_file=DEFAULT_BASELINE_PKL,
@@ -213,14 +223,7 @@ class TestLogprobsDense(unittest.TestCase):
    def setUpClass(cls):
        """Set up the test class - initialize the engine once for all tests."""
        print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
        cls.engine = sgl.Engine(**DEFAULT_ENGINE_CONFIG)

    @classmethod
    def tearDownClass(cls):
@@ -228,6 +231,26 @@ class TestLogprobsDense(unittest.TestCase):
        cls.engine.shutdown()
        torch.cuda.empty_cache()

    @classmethod
    def restart_engine_with_config(cls, **kwargs):
        """Create engine with custom configuration"""
        # Safely shutdown existing engine
        cls.engine.shutdown()
        torch.cuda.empty_cache()

        # Set chunk size
        chunk_size = kwargs.pop("chunk_size", None)
        if chunk_size is not None:
            print(f"Setting chunk size to {chunk_size}")
            os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
            os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = str(chunk_size)
        else:
            os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "False"

        # Create engine with merged configuration
        engine_config = {**DEFAULT_ENGINE_CONFIG, **kwargs}
        cls.engine = sgl.Engine(**engine_config)

    def load_test_data(self, baseline_file=None):
        """Load test data from local baseline file. In test mode, only local baseline is supported."""
        if not baseline_file:
@@ -281,141 +304,176 @@ class TestLogprobsDense(unittest.TestCase):
        # Load test data with retry mechanism
        records = self.load_test_data(baseline_file)

        # Fast configs for CI
        test_configs = [
            {"num_samples": NUM_SAMPLES},
            {"num_samples": 42, "chunk_size": 1, "max_running_requests": 16},
            {"num_samples": 42, "chunk_size": 2, "max_running_requests": 16},
            {"num_samples": 42, "chunk_size": 3, "max_running_requests": 16},
            {"num_samples": NUM_SAMPLES, "chunk_size": 16, "max_running_requests": 128},
            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 16},
            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 8},
            {"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 32},
            {
                "num_samples": NUM_SAMPLES,
                "chunk_size": 128,
                "max_running_requests": 128,
            },
            {"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 8},
            {"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 32},
            {
                "num_samples": NUM_SAMPLES,
                "chunk_size": 256,
                "max_running_requests": 128,
            },
        ]

        # Run tests
        for config in test_configs:
            with self.subTest(config=config):
                print(f"Testing with config: {config}")

                # Sample records for this config
                test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
                random.shuffle(test_records)

                # Calculate how many samples should return logprobs
                logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO)
                print(
                    f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}"
                )
                print(
                    f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})"
                )

                all_max, all_mean = [], []
                logprob_returned_count = 0

                # Process all records at once
                input_ids = [rec["ids"] for rec in test_records]
                logprob_start_lens = [rec["start_pos"] for rec in test_records]

                # Determine which samples should return logprobs (randomly selected)
                logprob_indices = set(
                    random.sample(range(len(test_records)), logprob_count)
                )
                return_logprob_array = [
                    sample_idx in logprob_indices
                    for sample_idx in range(len(test_records))
                ]

                # Sampling param per request
                sampling_params = [
                    {
                        "temperature": TEMPERATURE,
                        "top_p": 1.0,
                        "top_k": TOP_K,
                        "max_new_tokens": 1,
                    }
                    for _ in test_records
                ]

                # Some configs must restart the engine to take effect
                chunk_size = config.get("chunk_size", None)
                max_running_requests = config.get("max_running_requests", None)
                if chunk_size is not None or max_running_requests is not None:
                    self.restart_engine_with_config(
                        chunk_size=chunk_size,
                        max_running_requests=max_running_requests,
                    )

                outputs = self.engine.generate(
                    input_ids=input_ids,
                    sampling_params=sampling_params,
                    return_logprob=return_logprob_array,
                    logprob_start_len=logprob_start_lens,
                    top_logprobs_num=TOP_K,
                )

                for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)):
                    # Only compare logprobs for samples that should have them
                    if sample_idx in logprob_indices:
                        # Safe access to meta_info and input_top_logprobs
                        meta_info = output.get("meta_info")
                        input_top_logprobs = (
                            meta_info.get("input_top_logprobs") if meta_info else None
                        )
                        self.assertIsNotNone(
                            input_top_logprobs,
                            f"return_logprob enabled on this sample, but input_top_logprobs is None (length: {len(input_top_logprobs) if input_top_logprobs is not None else 'N/A'})",
                        )
                        baseline_meta = rec["meta"]
                        sglang_meta = meta_info

                        max_diff, mean_diff = self.compare_meta(
                            baseline_meta, sglang_meta
                        )
                        all_max.append(max_diff)
                        all_mean.append(mean_diff)
                        logprob_returned_count += 1
                    else:
                        # Verify that logprobs were not returned for this sample
                        meta_info = output.get("meta_info")
                        input_top_logprobs = (
                            meta_info.get("input_top_logprobs") if meta_info else None
                        )
                        output_token_ids_logprobs = (
                            meta_info.get("output_token_ids_logprobs")
                            if meta_info
                            else None
                        )
                        self.assertFalse(
                            input_top_logprobs,
                            f"return_logprob is disabled on this sample, Sample {sample_idx} should not have logprobs, content: {output_token_ids_logprobs}",
                        )

                max_of_max = max(all_max) if all_max else 0.0
                mean_of_mean = np.mean(all_mean) if all_mean else 0.0

                print(f"max Δ={max_of_max:.6g}")
                print(f"mean Δ={mean_of_mean:.6g}")
                print(
                    f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})"
                )

                # Verify correct number of logprobs returned
                self.assertEqual(
                    logprob_returned_count,
                    logprob_count,
                    f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
                )

                # Basic validation
                self.assertIsInstance(all_max, list)
                self.assertIsInstance(all_mean, list)
                self.assertGreater(
                    len(all_max),
                    0,
                    f"No test samples processed for config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}}",
                )

                # Tolerance checks with clear error messages
                failed_samples = []
                for sample_idx, (max_diff, mean_diff) in enumerate(
                    zip(all_max, all_mean)
                ):
                    if max_diff > DENSE_TOLERANCE_MAX_DIFF:
                        failed_samples.append(
                            f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
                        )
                    if mean_diff > DENSE_TOLERANCE_MEAN_DIFF:
                        failed_samples.append(
                            f"Sample {sample_idx}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}"
                        )

                if failed_samples:
                    self.fail(
                        f"Config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}} - Tolerance exceeded in {len(failed_samples)} samples:\n"
                        + "\n".join(failed_samples[:5])
                    )

def main():
    """Main function to handle command line arguments and run either generation or testing."""
...
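For completeness, a hedged end-to-end sketch of requesting prompt logprobs through the engine, mirroring the calls the test makes (the model path and token ids are placeholders, the sampling values are toy, and the chunking environment variables from above must already be set):

    import sglang as sgl

    engine = sgl.Engine(
        model_path="YOUR_MODEL_PATH",  # placeholder
        skip_tokenizer_init=True,
        random_seed=42,
    )
    outputs = engine.generate(
        input_ids=[[10, 20, 30, 40, 50]],  # toy pre-tokenized prompt
        sampling_params=[{"temperature": 1.0, "top_p": 1.0, "max_new_tokens": 1}],
        return_logprob=[True],
        logprob_start_len=[0],
        top_logprobs_num=5,
    )
    print(outputs[0]["meta_info"]["input_top_logprobs"][:2])
    engine.shutdown()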