"vscode:/vscode.git/clone" did not exist on "ec96407c3cf6192c349cb32c932b05fd441d68e2"
Unverified commit e2ac7888, authored by Qiaolin Yu, committed by GitHub

[2/2] Support deterministic inference for temperature > 0 (#10678)


Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
parent 86527a47
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 import torch.distributed as dist

@@ -65,6 +65,7 @@ class Sampler(nn.Module):
         return_logprob: bool,
         top_logprobs_nums: List[int],
         token_ids_logprobs: List[List[int]],
+        positions: torch.Tensor,
     ):
         """Run a sampler & compute logprobs and update logits_output accordingly.

@@ -77,6 +78,8 @@ class Sampler(nn.Module):
             batch_next_token_ids: next token IDs. If set, skip sampling and only
                 compute output logprobs It is used for speculative decoding which
                 performs sampling in draft workers.
+            positions: The positions of the tokens in the sequence. Used for deterministic sampling
+                to get the unique seed for each position.
         """
         logits = logits_output.next_token_logits

@@ -124,6 +127,8 @@ class Sampler(nn.Module):
                     sampling_info.top_ps,
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
+                    sampling_info.sampling_seed,
+                    positions,
                 )
             else:
                 raise ValueError(

@@ -189,6 +194,7 @@ class Sampler(nn.Module):
         Optimized for prefill-only scoring requests that need token probabilities
         but don't require next token generation.
         """
         if logits_output.next_token_logits is None:
             logger.warning("No logits available for logprob computation")
             return

@@ -230,8 +236,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
     need_min_p_sampling: bool,
+    sampling_seed: Optional[torch.Tensor],
+    positions: torch.Tensor,
 ):
"""A top-k, top-p and min-p sampling implementation with native pytorch operations.""" """
A top-k, top-p and min-p sampling implementation with native pytorch operations.
When sampling_seed is not None, deterministic inference will be enabled, it will sample
with the sampling_seed of each request.
"""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
     probs_sort[

@@ -243,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     if need_min_p_sampling:
         min_p_thresholds = probs_sort[:, 0] * min_ps
         probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-    sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def multinomial_with_seed(
+    inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
+) -> torch.Tensor:
+    """
+    Samples n elements from an input tensor `inputs` of shape (n, m) using
+    a unique random seed for each row. This is a deterministic batched alternative to
+    `torch.multinomial`.
+
+    Args:
+        inputs: A float tensor of shape (n, m) representing n categorical
+            distributions with m categories each. The values are treated
+            as weights and do not need to sum to 1.
+        seed: An integer tensor of shape (n,) containing the random seed
+            for each corresponding row in `inputs`.
+        positions: The positions of the tokens in the sequence. Used for deterministic
+            sampling to get the unique seed for each position.
+
+    Returns:
+        A tensor of shape (n, 1) where the i-th element is an index sampled
+        from the distribution in `inputs[i]` using `seed[i]`.
+    """
+    n, m = inputs.shape
+    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
+    step_seed = seed * 19349663 ^ positions * 73856093
+    seed_expanded = step_seed.unsqueeze(-1)
+    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    uniform_samples = (hashed % (2**24)).float() / (2**24)
+    epsilon = 1e-9
+    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    log_probs = torch.log(inputs + epsilon)
+    perturbed_log_probs = log_probs + gumbel_noise
+    return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)

 def sampling_from_probs_torch(probs: torch.Tensor):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
......
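For readers who want to sanity-check the sampler change above, the snippet below is a minimal standalone sketch (not part of the commit) that reproduces the hash-plus-Gumbel-argmax arithmetic from multinomial_with_seed and shows that identical (seed, position) pairs always yield identical draws. The helper name and the example tensors are invented for illustration; only the arithmetic mirrors the diff.

import torch


def seeded_multinomial(inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as multinomial_with_seed in the diff above:
    # hash (seed, position, column) into a pseudo-uniform sample, then Gumbel-argmax.
    n, m = inputs.shape
    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
    step_seed = seed * 19349663 ^ positions * 73856093
    hashed = step_seed.unsqueeze(-1) * 8589934591 ^ col_indices * 479001599
    uniform_samples = (hashed % (2**24)).float() / (2**24)
    eps = 1e-9
    gumbel_noise = -torch.log(-torch.log(uniform_samples + eps) + eps)
    return torch.argmax(torch.log(inputs + eps) + gumbel_noise, dim=1, keepdim=True)


weights = torch.rand(4, 8)              # 4 requests, 8 candidate tokens (unnormalized weights)
seeds = torch.tensor([42, 42, 7, 7])    # per-request sampling seeds
positions = torch.tensor([3, 3, 3, 9])  # per-request token positions
first = seeded_multinomial(weights, seeds, positions)
second = seeded_multinomial(weights, seeds, positions)
assert torch.equal(first, second)       # same seed and position -> same sampled index

Because every step is plain elementwise tensor arithmetic, the draw does not depend on batch size or on PyTorch's global RNG state, which is what makes the sampling reproducible across runs.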
@@ -67,7 +67,7 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import flatten_nested_list, support_triton
......
@@ -270,9 +270,7 @@ class TpModelWorker:
                     logits_output, model_worker_batch
                 )
             else:
-                next_token_ids = self.model_runner.sample(
-                    logits_output, model_worker_batch
-                )
+                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
             return logits_output, next_token_ids, can_run_cuda_graph
         else:
......
@@ -2049,7 +2049,6 @@ class ModelRunner:
         )
         self._preprocess_logits(logits_output, forward_batch.sampling_info)
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,

@@ -2057,6 +2056,12 @@
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
+            # For prefill, we only use the position of the last token.
+            (
+                forward_batch.positions
+                if forward_batch.forward_mode.is_decode()
+                else forward_batch.seq_lens - 1
+            ),
         )

         return next_token_ids
......
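The comment added in the ModelRunner hunk ("For prefill, we only use the position of the last token") determines which position feeds the seed hash: a prefill batch samples exactly one token per request, at position seq_len - 1, while a decode batch already carries the current position of each request. A small illustrative sketch of that selection follows; the ToyBatch class and values are invented, and only the positions / seq_lens - 1 expression mirrors the diff.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyBatch:
    is_decode: bool
    positions: Optional[torch.Tensor]  # current position of each request (decode)
    seq_lens: Optional[torch.Tensor]   # prompt length of each request (prefill)


def sampling_positions(batch: ToyBatch) -> torch.Tensor:
    # Mirrors: positions if forward_mode.is_decode() else seq_lens - 1
    return batch.positions if batch.is_decode else batch.seq_lens - 1


prefill = ToyBatch(is_decode=False, positions=None, seq_lens=torch.tensor([5, 12]))
print(sampling_positions(prefill))  # tensor([ 4, 11]) -> last prompt token of each request

decode = ToyBatch(is_decode=True, positions=torch.tensor([5, 12]), seq_lens=None)
print(sampling_positions(decode))   # tensor([ 5, 12]) -> current decoding positions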
@@ -60,6 +60,9 @@ class SamplingBatchInfo:
         Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
     ] = None

+    # Used for deterministic sampling
+    sampling_seed: Optional[torch.Tensor] = None
+
     # Device
     device: str = "cuda"

@@ -93,6 +96,15 @@
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
         )
+        sampling_seed = (
+            torch.tensor(
+                [r.sampling_params.sampling_seed for r in reqs],
+                dtype=torch.int32,
+                device=device,
+            )
+            if enable_deterministic
+            else None
+        )

         logit_bias = None
         if any(r.sampling_params.logit_bias is not None for r in reqs):

@@ -158,6 +170,7 @@
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
+            sampling_seed=sampling_seed,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
             need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),

@@ -239,9 +252,11 @@
             "top_ps",
             "top_ks",
             "min_ps",
+            "sampling_seed",
         ]:
             value = getattr(self, item, None)
-            setattr(self, item, value[keep_indices_device])
+            if value is not None:
+                setattr(self, item, value[keep_indices_device])

         if self.logit_bias is not None:
             self.logit_bias = self.logit_bias[keep_indices_device]

@@ -343,10 +358,12 @@
             "top_ps",
             "top_ks",
             "min_ps",
+            "sampling_seed",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
-            setattr(self, item, torch.cat([self_val, other_val]))
+            if self_val is not None and other_val is not None:
+                setattr(self, item, torch.cat([self_val, other_val]))

         self.is_all_greedy &= other.is_all_greedy
         self.need_top_p_sampling |= other.need_top_p_sampling
......
@@ -15,8 +15,11 @@
 from typing import Any, Dict, List, Optional, Union

+from sglang.srt.utils import get_bool_env_var
+
 _SAMPLING_EPS = 1e-6
 TOP_K_ALL = 1 << 30
+DEFAULT_SAMPLING_SEED = 42


 class SamplingParams:

@@ -53,6 +56,7 @@ class SamplingParams:
         custom_params: Optional[Dict[str, Any]] = None,
         stream_interval: Optional[int] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        sampling_seed: Optional[int] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop

@@ -80,6 +84,14 @@
         self.custom_params = custom_params
         self.stream_interval = stream_interval
         self.logit_bias = logit_bias
+        # Used for deterministic sampling
+        if (
+            get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE")
+            and sampling_seed is None
+        ):
+            # If deterministic inference is enabled and sampling_seed is not set, use the default seed
+            sampling_seed = DEFAULT_SAMPLING_SEED
+        self.sampling_seed = sampling_seed

         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
......
@@ -988,6 +988,12 @@ class ServerArgs:
                     "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
                 )

+            # Check some settings
+            self.sampling_backend = "pytorch"
+            logger.warning(
+                "Sampling backend is set to pytorch for deterministic inference."
+            )
+            # Currently, only FA3 supports radix cache. Support for other backends is in progress
             if self.attention_backend != "fa3":
                 self.disable_radix_cache = True
                 logger.warning(
......
@@ -29,6 +29,7 @@ class BenchArgs:
     port: int = 30000
     batch_size: int = 1
     temperature: float = 0.0
+    sampling_seed: int = None
     max_new_tokens: int = 100
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0

@@ -45,6 +46,9 @@
         parser.add_argument("--port", type=int, default=BenchArgs.port)
         parser.add_argument("--n-trials", type=int, default=50)
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument(
+            "--sampling-seed", type=int, default=BenchArgs.sampling_seed
+        )
         parser.add_argument(
             "--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
         )

@@ -92,6 +96,7 @@ def send_single(
             "max_new_tokens": args.max_new_tokens,
             "frequency_penalty": args.frequency_penalty,
             "presence_penalty": args.presence_penalty,
+            "sampling_seed": args.sampling_seed,
         },
         "return_logprob": args.return_logprob,
         "stream": args.stream,

@@ -140,6 +145,7 @@ def send_mixed(args, batch_size: int):
             "max_new_tokens": args.max_new_tokens,
             "frequency_penalty": args.frequency_penalty,
             "presence_penalty": args.presence_penalty,
+            "sampling_seed": args.sampling_seed,
         },
         "return_logprob": args.return_logprob,
         "stream": args.stream,

@@ -186,6 +192,7 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
             "max_new_tokens": args.max_new_tokens,
             "frequency_penalty": args.frequency_penalty,
             "presence_penalty": args.presence_penalty,
+            "sampling_seed": args.sampling_seed,
         },
         "return_logprob": args.return_logprob,
         "stream": args.stream,
......
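On the client side, the bench script above simply forwards --sampling-seed inside the sampling_params object of each request. The sketch below shows what such a payload could look like against a locally running server; the endpoint path, host/port, prompt text, and the use of the requests library are assumptions for illustration, while the sampling_params, return_logprob, and stream keys come from the diff.

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "temperature": 0.7,
        "max_new_tokens": 32,
        "sampling_seed": 42,  # same seed + same prompt -> reproducible completion
    },
    "return_logprob": False,
    "stream": False,
}
# Assumes a server launched with SGLANG_ENABLE_DETERMINISTIC_INFERENCE set and listening on port 30000.
response = requests.post("http://localhost:30000/generate", json=payload)
print(response.json())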
@@ -97,6 +97,7 @@ fn default_completion_request() -> CompletionRequest {
         lora_path: None,
         session_params: None,
         return_hidden_states: false,
+        sampling_seed: None,
         other: serde_json::Map::new(),
     }
 }
......
@@ -367,6 +367,10 @@ pub struct ChatCompletionRequest {
     /// Return model hidden states
     #[serde(default)]
     pub return_hidden_states: bool,
+
+    /// Random seed for sampling for deterministic outputs
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sampling_seed: Option<u64>,
 }

 impl GenerationRequest for ChatCompletionRequest {

@@ -608,6 +612,10 @@ pub struct CompletionRequest {
     #[serde(default)]
     pub return_hidden_states: bool,

+    /// Sampling seed for deterministic outputs
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sampling_seed: Option<u64>,
+
     /// Additional fields including bootstrap info for PD routing
     #[serde(flatten)]
     pub other: serde_json::Map<String, serde_json::Value>,

@@ -1749,6 +1757,8 @@ pub struct SamplingParams {
     pub stop_token_ids: Option<Vec<i32>>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub no_stop_trim: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sampling_seed: Option<u64>,
 }

 #[derive(Clone, Debug, Serialize, Deserialize)]
......
@@ -240,6 +240,7 @@ impl super::super::RouterTrait for OpenAIRouter {
             "chat_template_kwargs",
             "return_hidden_states",
             "repetition_penalty",
+            "sampling_seed",
         ] {
             obj.remove(key);
         }
......
@@ -68,6 +68,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
         lora_path: None,
         session_params: None,
         return_hidden_states: false,
+        sampling_seed: None,
         other: serde_json::Map::new(),
     }
 }
......