".github/vscode:/vscode.git/clone" did not exist on "0f9a0a45eeb8ef0f806ef15e400f4c0261669d12"
Unverified commit d5fa019c authored by Zhao Chen, committed by GitHub

feat: limit peak memory usage when computing logprobs (#6318)


Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
Co-authored-by: 赵晨阳 <zhaochen20@outlook.com>
parent fef3a6b6
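In short, the change adds an opt-in chunked path for prompt-logprob computation: instead of projecting all pruned hidden states through the LM head at once (a [num_tokens, vocab_size] logits tensor), the projection and log-softmax run over fixed-size row chunks, so peak memory scales with the chunk size rather than the prompt length. A minimal sketch of the idea, with a plain matmul standing in for the LM head (illustrative only, not the code in this diff):

import torch

def chunked_token_logprobs(hidden, lm_head_weight, target_ids, chunk_size=2048):
    # Peak memory ~ chunk_size * vocab_size instead of num_tokens * vocab_size.
    out = []
    for start in range(0, hidden.shape[0], chunk_size):
        chunk = hidden[start : start + chunk_size]
        logits = chunk @ lm_head_weight.T                  # [chunk, vocab]
        logprobs = torch.log_softmax(logits, dim=-1)
        ids = target_ids[start : start + chunk_size]
        out.append(logprobs[torch.arange(ids.numel(), device=ids.device), ids])
    return torch.cat(out)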
@@ -273,6 +273,10 @@ class Envs:
# Sparse Embeddings
SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
# Logits processor
SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK = EnvBool(False)
SGLANG_LOGITS_PROCESSER_CHUNK_SIZE = EnvInt(2048)
# Tool-Call behavior
SGLANG_TOOL_STRICT_LEVEL = EnvInt(ToolStrictLevel.OFF)
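# The two new variables above gate the chunked logprob path. A minimal usage
# sketch (values and model path are illustrative), set before the engine is
# constructed, mirroring what the test in this commit does:
import os
os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = "1024"

import sglang as sgl
engine = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct")  # illustrative model path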
@@ -15,7 +15,7 @@
import dataclasses
import logging
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
import triton
@@ -26,6 +26,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from sglang.srt.environ import envs
from sglang.srt.layers.dp_attention import (
DpPaddingMode,
attn_tp_all_gather,
@@ -53,6 +54,15 @@ logger = logging.getLogger(__name__)
_is_npu = is_npu()
@dataclasses.dataclass
class InputLogprobsResult:
input_token_logprobs: torch.Tensor
input_top_logprobs_val: Optional[List] = None
input_top_logprobs_idx: Optional[List] = None
input_token_ids_logprobs_val: Optional[List] = None
input_token_ids_logprobs_idx: Optional[List] = None
@dataclasses.dataclass
class LogitsProcessorOutput:
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -248,6 +258,11 @@ class LogitsProcessor(nn.Module):
):
self.final_logit_softcapping = None
# enable chunked logprobs processing
self.enable_logprobs_chunk = envs.SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK.value
# chunk size for logprobs processing
self.logprobs_chunk_size = envs.SGLANG_LOGITS_PROCESSER_CHUNK_SIZE.value
def compute_logprobs_for_multi_item_scoring(
self,
input_ids,
@@ -405,19 +420,41 @@ class LogitsProcessor(nn.Module):
sample_indices = None
input_logprob_indices = None
else:
# Input logprobs are required.
# Find 3 different indices.
# Prefill with input logprobs.
# Find 4 different indices.
# 1. pruned_states: hidden states that we want logprobs from.
# 2. sample_indices: Indices that have sampled tokens.
# 3. input_logprob_indices: Indices that have input logprob tokens.
# 4. token_to_seq_idx: map each token to its sequence index
#
# Example
# -------
# Suppose a batch (flattened by sequence):
# [t00, t01, t02, t03, t10, t11, t12, t13, t14, t20, t21, t22, t23, t24, t25]
# extend_seq_lens_cpu = [4, 5, 6]
# extend_logprob_start_lens_cpu = [0, 5, 3]
#
# Then, the indices are:
# pruned_states -> [t00, t01, t02, t03, t14, t23, t24, t25]
# sample_indices -> [3, 4, 7]
# input_logprob_indices -> [0, 1, 2, 3, 5, 6, 7]
# token_to_seq_idx -> [0, 0, 0, 0, 1, 2, 2, 2]
#
# If chunking is enabled with chunk_size = 3, pruned_states is processed in these chunks:
# [t00, t01, t02], [t03, t14, t23], [t24, t25]
sample_index_pt = -1
sample_indices = []
input_logprob_indices_pt = 0
input_logprob_indices = []
pt, pruned_states = 0, []
for extend_logprob_start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
token_to_seq_idx = []
for idx, (extend_logprob_start_len, extend_len) in enumerate(
zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
)
):
# This can happen in chunked prefill: we still need to sample 1 token,
# but we don't want to include it in the input logprobs.
@@ -430,6 +467,9 @@ class LogitsProcessor(nn.Module):
# by a caller.
assert extend_len > start_len
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
# Map each token to its sequence index, for chunked computation
# of input logprobs
token_to_seq_idx.extend([idx] * (extend_len - start_len))
pt += extend_len
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
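# Worked illustration (hypothetical batch from the comment above, chunk_size = 3)
# of how token_to_seq_idx, plus the trailing sentinel appended a few lines below,
# maps a chunk of pruned rows back to the sequences it covers:
token_to_seq_idx_example = [0, 0, 0, 0, 1, 2, 2, 2, 2]  # 8 pruned tokens + sentinel
total_size, chunk_size = 8, 3
for start_idx in range(0, total_size, chunk_size):
    end_idx = min(start_idx + chunk_size, total_size)
    chunk_slice = slice(
        token_to_seq_idx_example[start_idx], token_to_seq_idx_example[end_idx] + 1
    )
    print(start_idx, end_idx, chunk_slice.start, chunk_slice.stop)
# 0 3 0 1  -> chunk [t00, t01, t02] covers sequence 0
# 3 6 0 3  -> chunk [t03, t14, t23] covers sequences 0, 1, 2
# 6 8 2 3  -> chunk [t24, t25]      covers sequence 2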
@@ -441,6 +481,8 @@ class LogitsProcessor(nn.Module):
)
input_logprob_indices_pt += extend_len - start_len
# Append the last sequence's index once more as a sentinel, so that
# token_to_seq_idx[end_idx] is valid for the final chunk
token_to_seq_idx.append(len(logits_metadata.extend_seq_lens_cpu) - 1)
pruned_states = torch.cat(pruned_states)
sample_indices = torch.tensor(
sample_indices, device=pruned_states.device, dtype=torch.int64
@@ -449,12 +491,6 @@ class LogitsProcessor(nn.Module):
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
)
# Compute logits for both input and sampled tokens.
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
sampled_logits = (
logits[sample_indices] if sample_indices is not None else logits
)
hidden_states_to_store: Optional[torch.Tensor] = None
if logits_metadata.capture_hidden_mode.need_capture():
if logits_metadata.capture_hidden_mode.is_full():
@@ -482,67 +518,278 @@ class LogitsProcessor(nn.Module):
else:
assert False, "Should never reach"
del hidden_states
if not logits_metadata.extend_return_logprob:
# Compute logits for both input and sampled tokens.
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
sampled_logits = (
logits[sample_indices] if sample_indices is not None else logits
)
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
)
else:
# Start to process input logprobs
# Normalize the logprob w/o temperature, top-p
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu,
device=pruned_states.device,
)
if logits_metadata.temp_scaled_logprobs:
logits_metadata.temperature = torch.repeat_interleave(
logits_metadata.temperature.view(-1),
pruned_lens,
).view(-1, 1)
if logits_metadata.top_p_normalized_logprobs:
logits_metadata.top_p = torch.repeat_interleave(
logits_metadata.top_p,
pruned_lens,
)
# Determine whether to use chunked or non-chunked logits processing.
# Skip chunking if:
# 1. Chunking is disabled
# 2. The total number of pruned tokens does not exceed the chunk size
# 3. DP attention all-gather is in use (enable "enable_dp_lm_head" to allow chunking in that case)
should_skip_chunking = (
not self.enable_logprobs_chunk
or pruned_states.shape[0] <= self.logprobs_chunk_size
or self.do_tensor_parallel_all_gather_dp_attn
)
if should_skip_chunking:
# Compute logits for both input and sampled tokens.
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
sampled_logits = (
logits[sample_indices] if sample_indices is not None else logits
)
input_logprobs = logits[input_logprob_indices]
del hidden_states, logits
del logits
# Normalize the logprob w/o temperature, top-p
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu,
device=input_logprobs.device,
logprobs_result = self._process_input_logprobs(
input_logprobs, logits_metadata
)
if logits_metadata.temp_scaled_logprobs:
logits_metadata.temperature = torch.repeat_interleave(
logits_metadata.temperature.view(-1),
pruned_lens,
).view(-1, 1)
if logits_metadata.top_p_normalized_logprobs:
logits_metadata.top_p = torch.repeat_interleave(
logits_metadata.top_p,
pruned_lens,
else:
(logprobs_result, sampled_logits) = self._process_input_logprobs_by_chunk(
pruned_states,
sample_indices,
input_logprob_indices,
token_to_seq_idx,
lm_head,
logits_metadata,
)
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
input_token_logprobs=logprobs_result.input_token_logprobs,
input_top_logprobs_val=logprobs_result.input_top_logprobs_val,
input_top_logprobs_idx=logprobs_result.input_top_logprobs_idx,
input_token_ids_logprobs_val=logprobs_result.input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=logprobs_result.input_token_ids_logprobs_idx,
)
def _process_input_logprobs(self, input_logprobs, logits_metadata):
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata
)
# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
(
input_top_logprobs_val,
input_top_logprobs_idx,
) = self.get_top_logprobs(input_logprobs, logits_metadata)
else:
input_top_logprobs_val = input_top_logprobs_idx = None
# Get the logprob of given token id
if logits_metadata.extend_token_ids_logprob:
(
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
else:
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
logits_metadata.extend_input_logprob_token_ids_gpu,
]
return InputLogprobsResult(
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)
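# The gather above picks, for each pruned position, the logprob of the token id
# recorded in extend_input_logprob_token_ids_gpu. A toy illustration of the
# indexing pattern (shapes and values are made up):
import torch
toy_logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)       # 3 positions, vocab of 5
toy_next_ids = torch.tensor([2, 0, 4])                             # id observed at each position
toy_token_logprobs = toy_logprobs[torch.arange(3), toy_next_ids]   # shape [3]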
def _process_input_logprobs_by_chunk(
self,
pruned_states: torch.Tensor,
sample_indices: torch.Tensor,
input_logprob_indices: torch.Tensor,
token_to_seq_idx: list[int],
lm_head: VocabParallelEmbedding,
logits_metadata: LogitsMetadata,
) -> Tuple[InputLogprobsResult, torch.Tensor]:
"""
compute logprobs for the output token from the hidden states.
To avoid using too much memory, we split pruned_states into chunks of
rows to compute input_logprobs separately, then concatenate the results.
Returns:
InputLogprobsResult: logprobs result
torch.Tensor: sampled logits
"""
# The peak memory usage is proportional to the chunk size.
chunk_size = self.logprobs_chunk_size
total_size = pruned_states.shape[0]
num_chunks = (total_size + chunk_size - 1) // chunk_size
input_token_logprobs = []
if logits_metadata.extend_return_top_logprob:
input_top_logprobs_val = []
input_top_logprobs_idx = []
else:
input_top_logprobs_val = None
input_top_logprobs_idx = None
if logits_metadata.extend_token_ids_logprob:
input_token_ids_logprobs_val = []
input_token_ids_logprobs_idx = []
else:
input_token_ids_logprobs_val = None
input_token_ids_logprobs_idx = None
# If a single sequence is split across multiple chunks, keep track of how many
# of its pruned tokens were already processed in previous chunks.
split_len_topk = 0
split_len_token_ids = 0
for i in range(num_chunks):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, total_size)
# Get indices for this chunk
chunk_mask = (input_logprob_indices >= start_idx) & (
input_logprob_indices < end_idx
)
global_indices = input_logprob_indices[chunk_mask]
chunk_indices = global_indices - start_idx
# Get the positions in the original array where chunk_mask is True
# This is needed to correctly index into extend_input_logprob_token_ids_gpu
mask_indices = torch.nonzero(chunk_mask, as_tuple=True)[0]
# Get the logits for this chunk
chunk_states = pruned_states[start_idx:end_idx]
chunk_logits = self._get_logits(chunk_states, lm_head, logits_metadata)
# Initialize sampled_logits on first chunk
if i == 0:
sampled_logits = torch.empty(
(sample_indices.shape[0], chunk_logits.shape[1]),
dtype=chunk_logits.dtype,
device=chunk_logits.device,
)
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata
# Handle sampled logits for the chunk if needed
# This must be done before the continue statement to ensure all sampled_logits are filled
chunk_sample_mask = (sample_indices >= start_idx) & (
sample_indices < end_idx
)
if chunk_sample_mask.any():
chunk_sample_indices = sample_indices[chunk_sample_mask] - start_idx
sampled_logits[chunk_sample_mask] = chunk_logits[chunk_sample_indices]
# If there are no input logprobs in this chunk, skip the rest
if chunk_indices.numel() == 0:
continue
# Compute the logprobs of the chunk
chunk_input_logprobs = chunk_logits[chunk_indices]
chunk_temperature = (
logits_metadata.temperature[global_indices]
if logits_metadata.temperature is not None
else None
)
chunk_top_p = (
logits_metadata.top_p[global_indices]
if logits_metadata.top_p is not None
else None
)
chunk_input_logprobs = self.compute_temp_top_p_normalized_logprobs(
chunk_input_logprobs,
logits_metadata,
chunk_top_p,
chunk_temperature,
)
# Determine which sequences this chunk covers, via token_to_seq_idx
chunk_slice = slice(
token_to_seq_idx[start_idx], token_to_seq_idx[end_idx] + 1
)
# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
(
top_k_nums = logits_metadata.top_logprobs_nums[chunk_slice]
pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[
chunk_slice
]
split_len_topk = self.get_top_logprobs_chunk(
chunk_input_logprobs,
logits_metadata,
top_k_nums,
pruned_lens,
input_top_logprobs_val,
input_top_logprobs_idx,
) = self.get_top_logprobs(input_logprobs, logits_metadata)
else:
input_top_logprobs_val = input_top_logprobs_idx = None
split_len_topk,
)
# Get the logprob of given token id
if logits_metadata.extend_token_ids_logprob:
(
token_ids_logprobs = logits_metadata.token_ids_logprobs[chunk_slice]
pruned_lens = logits_metadata.extend_logprob_pruned_lens_cpu[
chunk_slice
]
split_len_token_ids = self.get_token_ids_logprobs_chunk(
chunk_input_logprobs,
logits_metadata,
token_ids_logprobs,
pruned_lens,
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
else:
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
split_len_token_ids,
)
input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
logits_metadata.extend_input_logprob_token_ids_gpu,
# Get the logprob of the requested token ids
chunk_input_token_logprobs = chunk_input_logprobs[
torch.arange(
chunk_input_logprobs.shape[0], device=chunk_input_logprobs.device
),
logits_metadata.extend_input_logprob_token_ids_gpu[mask_indices],
]
input_token_logprobs.append(chunk_input_token_logprobs)
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
# Concatenate the results
input_token_logprobs = torch.cat(input_token_logprobs, dim=0)
return (
InputLogprobsResult(
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
hidden_states=hidden_states_to_store,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)
),
sampled_logits,
)
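# Worked illustration of the chunk masking above, continuing the hypothetical
# batch from the indexing comment (chunk_size = 3; the second chunk holds rows 3..5).
# All values are assumptions for illustration, not taken from a real run.
import torch
ex_input_logprob_indices = torch.tensor([0, 1, 2, 3, 5, 6, 7])
ex_sample_indices = torch.tensor([3, 4, 7])
start_idx, end_idx = 3, 6
chunk_mask = (ex_input_logprob_indices >= start_idx) & (ex_input_logprob_indices < end_idx)
global_indices = ex_input_logprob_indices[chunk_mask]       # tensor([3, 5])
chunk_indices = global_indices - start_idx                  # tensor([0, 2]): rows inside the chunk
mask_indices = torch.nonzero(chunk_mask, as_tuple=True)[0]  # tensor([3, 4]): positions in the
                                                            # flattened per-token metadata arrays
chunk_sample_mask = (ex_sample_indices >= start_idx) & (ex_sample_indices < end_idx)
# -> tensor([True, True, False]): the sampled tokens of sequences 0 and 1 fall in this chunk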
def _get_logits(
self,
@@ -691,6 +938,80 @@ class LogitsProcessor(nn.Module):
return input_top_logprobs_val, input_top_logprobs_idx
@staticmethod
def get_top_logprobs_chunk(
logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
top_k_nums: List[int],
pruned_lens: List[int],
input_top_logprobs_val: List,
input_top_logprobs_idx: List,
split_pruned_len: int,
) -> int:
"""Get top-k logprobs for each sequence in the chunk.
Args:
logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
logits_metadata: Metadata containing top-k and pruned length info
top_k_nums: List of top-k numbers for each sequence
pruned_lens: List of pruned lengths for each sequence
input_top_logprobs_val: List to store top-k logprob values
input_top_logprobs_idx: List to store top-k token indices
split_pruned_len: Length of pruned tokens from previous chunk
Returns:
int: Number of tokens of the last (split) sequence already emitted; pass this as split_pruned_len for the next chunk
"""
# No sequences in the chunk
if logprobs.shape[0] == 0:
return 0
max_k = max(logits_metadata.top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
pt = 0
next_split_pruned_len = 0
for n, (k, pruned_len) in enumerate(zip(top_k_nums, pruned_lens)):
if n == 0:
# For the first sequence, adjust the pruned length
pruned_len -= split_pruned_len
else:
# After the first sequence, no split in the middle
split_pruned_len = 0
if pruned_len <= 0:
# If the pruned length is <= 0, there are no top-k
# logprobs to process for this sequence
input_top_logprobs_val.append([])
input_top_logprobs_idx.append([])
continue
# Get the top-k logprobs
val = []
idx = []
for j in range(pruned_len):
# Handle remaining tokens in next chunk if any
if pt + j >= len(values):
next_split_pruned_len = split_pruned_len + j
break
# Append the top-k logprobs
val.append(values[pt + j][:k])
idx.append(indices[pt + j][:k])
# Append or extend based on whether the sequence was split across chunks
if len(val) > 0:
if split_pruned_len > 0:
input_top_logprobs_val[-1].extend(val)
input_top_logprobs_idx[-1].extend(idx)
else:
input_top_logprobs_val.append(val)
input_top_logprobs_idx.append(idx)
pt += pruned_len
return next_split_pruned_len
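# Toy standalone mirror of the split_pruned_len carry logic above, for
# illustration only (names and values are made up, not part of this diff):
from typing import List

def emit_chunk(rows: List[list], pruned_lens: List[int], out: List[list], split_pruned_len: int) -> int:
    pt, next_split = 0, 0
    for n, pruned_len in enumerate(pruned_lens):
        if n == 0:
            pruned_len -= split_pruned_len        # skip tokens emitted in earlier chunks
        else:
            split_pruned_len = 0
        if pruned_len <= 0:
            out.append([])
            continue
        vals = []
        for j in range(pruned_len):
            if pt + j >= len(rows):               # the sequence spills into the next chunk
                next_split = split_pruned_len + j
                break
            vals.append(rows[pt + j])
        if vals:
            if split_pruned_len > 0:
                out[-1].extend(vals)              # continue the entry started in the previous chunk
            else:
                out.append(vals)
        pt += pruned_len
    return next_split

# One sequence with 5 pruned tokens, split across chunks holding 3 and 2 rows:
acc: List[list] = []
carry = emit_chunk([["a"], ["b"], ["c"]], [5], acc, 0)  # carry == 3, acc == [[["a"], ["b"], ["c"]]]
carry = emit_chunk([["d"], ["e"]], [5], acc, carry)     # carry == 0, acc[-1] now has all 5 entries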
@staticmethod
def get_token_ids_logprobs(
all_logprobs: torch.Tensor,
@@ -724,9 +1045,86 @@ class LogitsProcessor(nn.Module):
return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
@staticmethod
def get_token_ids_logprobs_chunk(
logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
token_ids_logprobs: List[int],
pruned_lens: List[int],
input_token_ids_logprobs_val: List,
input_token_ids_logprobs_idx: List,
split_pruned_len: int = 0,
):
"""Get token_ids logprobs for each sequence in the chunk.
Args:
logprobs: Log probabilities tensor of shape [seq_len, vocab_size]
logits_metadata: Metadata containing token IDs and pruned length info
token_ids_logprobs: List of token IDs for each sequence
pruned_lens: List of pruned lengths for each sequence
input_token_ids_logprobs_val: List to store token logprob values
input_token_ids_logprobs_idx: List to store token indices
split_pruned_len: Length of pruned tokens from previous chunk
Returns:
int: Number of tokens of the last (split) sequence already emitted; pass this as split_pruned_len for the next chunk
"""
# No sequences in the chunk
if logprobs.shape[0] == 0:
return 0
pt = 0
next_split_pruned_len = 0
for n, (token_ids, pruned_len) in enumerate(
zip(
token_ids_logprobs,
pruned_lens,
)
):
# Adjust pruned length for first sequence
if n == 0:
pruned_len -= split_pruned_len
else:
split_pruned_len = 0
if pruned_len <= 0:
# If the pruned length is <= 0, there are no token-id
# logprobs to process for this sequence
input_token_ids_logprobs_val.append([])
input_token_ids_logprobs_idx.append([])
continue
# Get the token ids logprobs
val = []
idx = []
for j in range(pruned_len):
# Handle remaining tokens in next chunk if any
if pt + j >= logprobs.shape[0]:
next_split_pruned_len = split_pruned_len + j
break
if token_ids is not None:
val.append(logprobs[pt + j, token_ids].tolist())
idx.append(token_ids)
# Append or extend based on whether the sequence was split across chunks
if len(val) > 0:
if split_pruned_len > 0:
input_token_ids_logprobs_val[-1].extend(val)
input_token_ids_logprobs_idx[-1].extend(idx)
else:
input_token_ids_logprobs_val.append(val)
input_token_ids_logprobs_idx.append(idx)
pt += pruned_len
return next_split_pruned_len
@staticmethod
def compute_temp_top_p_normalized_logprobs(
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
last_logits: torch.Tensor,
logits_metadata: LogitsMetadata,
top_p: Optional[torch.Tensor] = None,
temperature: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
compute logprobs for the output token from the given logits.
@@ -734,21 +1132,23 @@ class LogitsProcessor(nn.Module):
Returns:
torch.Tensor: logprobs from logits
"""
if top_p is None:
top_p = logits_metadata.top_p
if temperature is None:
temperature = logits_metadata.temperature
# Scale logits if temperature scaling is enabled
if logits_metadata.temp_scaled_logprobs:
last_logits = last_logits / logits_metadata.temperature
last_logits = last_logits / temperature
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
if (
logits_metadata.top_p_normalized_logprobs
and (logits_metadata.top_p != 1.0).any()
):
if logits_metadata.top_p_normalized_logprobs and (top_p != 1.0).any():
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
probs = torch.softmax(last_logits, dim=-1)
del last_logits
probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
probs = top_p_normalize_probs_torch(probs, top_p)
return torch.log(probs)
else:
return torch.nn.functional.log_softmax(last_logits, dim=-1)
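# For intuition, a self-contained sketch of the temperature + top-p normalized
# logprobs this helper produces. The real code delegates the top-p step to
# sglang.srt.layers.sampler.top_p_normalize_probs_torch; the version below is an
# assumed equivalent, with per-row temperature and top_p tensors:
import torch

def temp_top_p_logprobs(logits: torch.Tensor, temperature: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    probs = torch.softmax(logits / temperature.view(-1, 1), dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    keep = (cumsum - sorted_probs) < top_p.view(-1, 1)    # smallest prefix reaching top_p mass
    kept = sorted_probs * keep
    kept = kept / kept.sum(dim=-1, keepdim=True)          # renormalize within the kept set
    probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, kept)
    return torch.log(probs)                               # filtered-out tokens get -inf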
@@ -85,6 +85,16 @@ MAX_LEN = 20000
DEFAULT_BASELINE_PKL = "sglang_baseline_local.pkl"
DEFAULT_META_JSON = "baseline_meta_preview.json"
# Default engine configuration
DEFAULT_ENGINE_CONFIG = {
"model_path": DENSE_MODEL_NAME,
"random_seed": 42,
"skip_tokenizer_init": True,
"mem_fraction_static": 0.8,
"enable_deterministic_inference": True,
"attention_backend": "flashinfer",
}
def generate_baseline(
baseline_file=DEFAULT_BASELINE_PKL,
@@ -213,14 +223,7 @@ class TestLogprobsDense(unittest.TestCase):
def setUpClass(cls):
"""Set up the test class - initialize the engine once for all tests."""
print(f"Launching SGLang Engine with {DENSE_MODEL_NAME}...")
cls.engine = sgl.Engine(
model_path=DENSE_MODEL_NAME,
random_seed=42,
attention_backend="flashinfer",
enable_deterministic_inference=True,
skip_tokenizer_init=True,
mem_fraction_static=0.80,
)
cls.engine = sgl.Engine(**DEFAULT_ENGINE_CONFIG)
@classmethod
def tearDownClass(cls):
@@ -228,6 +231,26 @@ class TestLogprobsDense(unittest.TestCase):
cls.engine.shutdown()
torch.cuda.empty_cache()
@classmethod
def restart_engine_with_config(cls, **kwargs):
"""Create engine with custom configuration"""
# Safely shutdown existing engine
cls.engine.shutdown()
torch.cuda.empty_cache()
# Configure chunked logprob processing via environment variables
chunk_size = kwargs.pop("chunk_size", None)
if chunk_size is not None:
print(f"Setting chunk size to {chunk_size}")
os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = str(chunk_size)
else:
os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "False"
# Create engine with merged configuration
engine_config = {**DEFAULT_ENGINE_CONFIG, **kwargs}
cls.engine = sgl.Engine(**engine_config)
def load_test_data(self, baseline_file=None):
"""Load test data from local baseline file. In test mode, only local baseline is supported."""
if not baseline_file:
@@ -281,141 +304,176 @@ class TestLogprobsDense(unittest.TestCase):
# Load test data with retry mechanism
records = self.load_test_data(baseline_file)
with self.subTest(
config={
# Fast configs for CI
test_configs = [
{"num_samples": NUM_SAMPLES},
{"num_samples": 42, "chunk_size": 1, "max_running_requests": 16},
{"num_samples": 42, "chunk_size": 2, "max_running_requests": 16},
{"num_samples": 42, "chunk_size": 3, "max_running_requests": 16},
{"num_samples": NUM_SAMPLES, "chunk_size": 16, "max_running_requests": 128},
{"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 16},
{"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 8},
{"num_samples": NUM_SAMPLES, "chunk_size": 128, "max_running_requests": 32},
{
"num_samples": NUM_SAMPLES,
"logprob_sample_ratio": LOGPROB_SAMPLE_RATIO,
"temperature": TEMPERATURE,
}
):
# Sample records for this config
test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
random.shuffle(test_records)
# Calculate how many samples should return logprobs
logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO)
print(
f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}"
)
print(
f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})"
)
all_max, all_mean = [], []
logprob_returned_count = 0
"chunk_size": 128,
"max_running_requests": 128,
},
{"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 8},
{"num_samples": NUM_SAMPLES, "chunk_size": 256, "max_running_requests": 32},
{
"num_samples": NUM_SAMPLES,
"chunk_size": 256,
"max_running_requests": 128,
},
]
# Run tests
for config in test_configs:
with self.subTest(config=config):
print(f"Testing with config: {config}")
# Sample records for this config
test_records = random.sample(records, k=min(NUM_SAMPLES, len(records)))
random.shuffle(test_records)
# Calculate how many samples should return logprobs
logprob_count = int(len(test_records) * LOGPROB_SAMPLE_RATIO)
print(
f"Testing with {len(test_records)} samples, temperature={TEMPERATURE}"
)
print(
f"Will return logprobs for {logprob_count} samples (ratio: {LOGPROB_SAMPLE_RATIO})"
)
# Process all records at once
input_ids = [rec["ids"] for rec in test_records]
logprob_start_lens = [rec["start_pos"] for rec in test_records]
all_max, all_mean = [], []
logprob_returned_count = 0
# Determine which samples should return logprobs (randomly selected)
logprob_indices = set(
random.sample(range(len(test_records)), logprob_count)
)
return_logprob_array = [
sample_idx in logprob_indices for sample_idx in range(len(test_records))
]
# Sampling param per request
sampling_params = [
{
"temperature": TEMPERATURE,
"top_p": 1.0,
"top_k": TOP_K,
"max_new_tokens": 1,
}
for _ in test_records
]
outputs = self.engine.generate(
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob_array,
logprob_start_len=logprob_start_lens,
top_logprobs_num=TOP_K,
)
# Process all records at once
input_ids = [rec["ids"] for rec in test_records]
logprob_start_lens = [rec["start_pos"] for rec in test_records]
for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)):
# Only compare logprobs for samples that should have them
if sample_idx in logprob_indices:
# Safe access to meta_info and input_top_logprobs
meta_info = output.get("meta_info")
input_top_logprobs = (
meta_info.get("input_top_logprobs") if meta_info else None
)
self.assertIsNotNone(
input_top_logprobs,
f"return_logprob enabled on this sample, but input_top_logprobs is None (length: {len(input_top_logprobs) if input_top_logprobs is not None else 'N/A'})",
)
baseline_meta = rec["meta"]
sglang_meta = meta_info
max_diff, mean_diff = self.compare_meta(baseline_meta, sglang_meta)
all_max.append(max_diff)
all_mean.append(mean_diff)
logprob_returned_count += 1
else:
# Verify that logprobs were not returned for this sample
meta_info = output.get("meta_info")
input_top_logprobs = (
meta_info.get("input_top_logprobs") if meta_info else None
)
output_token_ids_logprobs = (
meta_info.get("output_token_ids_logprobs")
if meta_info
else None
)
self.assertFalse(
input_top_logprobs,
f"return_logprob is disabled on this sample, Sample {sample_idx} should not have logprobs, content: {output_token_ids_logprobs}",
# Determine which samples should return logprobs (randomly selected)
logprob_indices = set(
random.sample(range(len(test_records)), logprob_count)
)
return_logprob_array = [
sample_idx in logprob_indices
for sample_idx in range(len(test_records))
]
# Sampling param per request
sampling_params = [
{
"temperature": TEMPERATURE,
"top_p": 1.0,
"top_k": TOP_K,
"max_new_tokens": 1,
}
for _ in test_records
]
# Some configs must restart the engine to take effect
chunk_size = config.get("chunk_size", None)
max_running_requests = config.get("max_running_requests", None)
if chunk_size is not None or max_running_requests is not None:
self.restart_engine_with_config(
chunk_size=chunk_size,
max_running_requests=max_running_requests,
)
max_of_max = max(all_max) if all_max else 0.0
mean_of_mean = np.mean(all_mean) if all_mean else 0.0
outputs = self.engine.generate(
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob_array,
logprob_start_len=logprob_start_lens,
top_logprobs_num=TOP_K,
)
print(f"max Δ={max_of_max:.6g}")
print(f"mean Δ={mean_of_mean:.6g}")
print(
f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})"
)
for sample_idx, (rec, output) in enumerate(zip(test_records, outputs)):
# Only compare logprobs for samples that should have them
if sample_idx in logprob_indices:
# Safe access to meta_info and input_top_logprobs
meta_info = output.get("meta_info")
input_top_logprobs = (
meta_info.get("input_top_logprobs") if meta_info else None
)
self.assertIsNotNone(
input_top_logprobs,
f"return_logprob enabled on this sample, but input_top_logprobs is None (length: {len(input_top_logprobs) if input_top_logprobs is not None else 'N/A'})",
)
baseline_meta = rec["meta"]
sglang_meta = meta_info
max_diff, mean_diff = self.compare_meta(
baseline_meta, sglang_meta
)
all_max.append(max_diff)
all_mean.append(mean_diff)
logprob_returned_count += 1
else:
# Verify that logprobs were not returned for this sample
meta_info = output.get("meta_info")
input_top_logprobs = (
meta_info.get("input_top_logprobs") if meta_info else None
)
output_token_ids_logprobs = (
meta_info.get("output_token_ids_logprobs")
if meta_info
else None
)
self.assertFalse(
input_top_logprobs,
f"return_logprob is disabled on this sample, Sample {sample_idx} should not have logprobs, content: {output_token_ids_logprobs}",
)
max_of_max = max(all_max) if all_max else 0.0
mean_of_mean = np.mean(all_mean) if all_mean else 0.0
print(f"max Δ={max_of_max:.6g}")
print(f"mean Δ={mean_of_mean:.6g}")
print(
f"logprobs returned for {logprob_returned_count} samples (expected: {logprob_count})"
)
# Verify correct number of logprobs returned
self.assertEqual(
logprob_returned_count,
logprob_count,
f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
)
# Verify correct number of logprobs returned
self.assertEqual(
logprob_returned_count,
logprob_count,
f"Expected {logprob_count} samples with logprobs, got {logprob_returned_count}",
)
# Basic validation
self.assertIsInstance(all_max, list)
self.assertIsInstance(all_mean, list)
self.assertGreater(
len(all_max),
0,
f"No test samples processed for config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}}",
)
# Basic validation
self.assertIsInstance(all_max, list)
self.assertIsInstance(all_mean, list)
self.assertGreater(
len(all_max),
0,
f"No test samples processed for config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}}",
)
# Tolerance checks with clear error messages
failed_samples = []
for sample_idx, (max_diff, mean_diff) in enumerate(zip(all_max, all_mean)):
if max_diff > DENSE_TOLERANCE_MAX_DIFF:
failed_samples.append(
f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
)
if mean_diff > DENSE_TOLERANCE_MEAN_DIFF:
failed_samples.append(
f"Sample {sample_idx}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}"
# Tolerance checks with clear error messages
failed_samples = []
for sample_idx, (max_diff, mean_diff) in enumerate(
zip(all_max, all_mean)
):
if max_diff > DENSE_TOLERANCE_MAX_DIFF:
failed_samples.append(
f"Sample {sample_idx}: max_diff={max_diff:.6g} > {DENSE_TOLERANCE_MAX_DIFF}"
)
if mean_diff > DENSE_TOLERANCE_MEAN_DIFF:
failed_samples.append(
f"Sample {sample_idx}: mean_diff={mean_diff:.6g} > {DENSE_TOLERANCE_MEAN_DIFF}"
)
if failed_samples:
self.fail(
f"Config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}} - Tolerance exceeded in {len(failed_samples)} samples:\n"
+ "\n".join(failed_samples[:5])
)
if failed_samples:
self.fail(
f"Config {{'num_samples': {NUM_SAMPLES}, 'logprob_sample_ratio': {LOGPROB_SAMPLE_RATIO}, 'temperature': {TEMPERATURE}}} - Tolerance exceeded in {len(failed_samples)} samples:\n"
+ "\n".join(failed_samples[:5])
)
def main():
"""Main function to handle command line arguments and run either generation or testing."""