"vscode:/vscode.git/clone" did not exist on "e5b3c1f4734e9d5bb28953017bc3f0a6dedc3aa0"
Unverified commit e2ac7888 authored by Qiaolin Yu, committed by GitHub

[2/2] Support deterministic inference for temperature > 0 (#10678)


Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
parent 86527a47
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
......@@ -65,6 +65,7 @@ class Sampler(nn.Module):
return_logprob: bool,
top_logprobs_nums: List[int],
token_ids_logprobs: List[List[int]],
positions: torch.Tensor,
):
"""Run a sampler & compute logprobs and update logits_output accordingly.
......@@ -77,6 +78,8 @@ class Sampler(nn.Module):
batch_next_token_ids: next token IDs. If set, skip sampling and only
compute output logprobs. It is used for speculative decoding, which
performs sampling in draft workers.
positions: The positions of the tokens in the sequence. Used in deterministic
sampling to derive a unique seed for each position.
"""
logits = logits_output.next_token_logits
......@@ -124,6 +127,8 @@ class Sampler(nn.Module):
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
sampling_info.sampling_seed,
positions,
)
else:
raise ValueError(
......@@ -189,6 +194,7 @@ class Sampler(nn.Module):
Optimized for prefill-only scoring requests that need token probabilities
but don't require next token generation.
"""
if logits_output.next_token_logits is None:
logger.warning("No logits available for logprob computation")
return
......@@ -230,8 +236,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
top_ps: torch.Tensor,
min_ps: torch.Tensor,
need_min_p_sampling: bool,
sampling_seed: Optional[torch.Tensor],
positions: torch.Tensor,
):
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
"""
A top-k, top-p and min-p sampling implementation with native pytorch operations.
When sampling_seed is not None, deterministic sampling is enabled and each request
is sampled with its own sampling_seed.
"""
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[
......@@ -243,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
if need_min_p_sampling:
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
sampled_index = torch.multinomial(probs_sort, num_samples=1)
if sampling_seed is not None:
sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions)
else:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
# int32 range is enough to represent the token ids
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
return batch_next_token_ids
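Note that `sampled_index` is a rank in the sorted probability tensor, so the final `torch.gather` on `probs_idx` maps it back to vocabulary token ids. A minimal standalone sketch of that mapping (toy tensors, not taken from SGLang):

```python
import torch

# Toy batch of 2 rows over a 4-token vocabulary.
probs = torch.tensor([[0.1, 0.6, 0.2, 0.1],
                      [0.5, 0.1, 0.3, 0.1]])

probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

# Suppose sampling picked the second-largest entry in each row.
sampled_index = torch.tensor([[1], [1]])

# Translate "rank in sorted order" back to the original vocabulary token id.
batch_next_token_ids = torch.gather(
    probs_idx.to(torch.int32), dim=1, index=sampled_index
).view(-1)
print(batch_next_token_ids)  # tensor([2, 2], dtype=torch.int32)
```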
def multinomial_with_seed(
inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
) -> torch.Tensor:
"""
Samples n elements from an input tensor `inputs` of shape (n, m) using
a unique random seed for each row. This is a deterministic batched alternative to
`torch.multinomial`.
Args:
inputs: A float tensor of shape (n, m) representing n categorical
distributions with m categories each. The values are treated
as weights and do not need to sum to 1.
seed: An integer tensor of shape (n,) containing the random seed
for each corresponding row in `inputs`.
positions: The positions of the tokens in the sequence. Used in deterministic
sampling to derive a unique seed for each position.
Returns:
A tensor of shape (n, 1) where the i-th entry is an index sampled
from the distribution in `inputs[i]` using `seed[i]`.
"""
n, m = inputs.shape
col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
step_seed = seed * 19349663 ^ positions * 73856093
seed_expanded = step_seed.unsqueeze(-1)
hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
uniform_samples = (hashed % (2**24)).float() / (2**24)
epsilon = 1e-9
gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
log_probs = torch.log(inputs + epsilon)
perturbed_log_probs = log_probs + gumbel_noise
return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
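`multinomial_with_seed` avoids torch's global RNG entirely: it hashes (seed, position, column) into pseudo-uniform numbers, turns them into Gumbel noise, and takes the argmax of log-probability plus noise. By the Gumbel-max trick, that argmax is a sample from the categorical distribution, and because the noise depends only on the seed and position, the sample is reproducible. Below is a standalone sketch that replicates the recipe to demonstrate the reproducibility (toy values; the constants mirror the function above, but this is an illustration, not the SGLang module):

```python
import torch


def seeded_multinomial(inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Deterministic per-row categorical sampling via hashed uniforms + Gumbel-max."""
    n, m = inputs.shape
    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
    step_seed = seed * 19349663 ^ positions * 73856093
    hashed = step_seed.unsqueeze(-1) * 8589934591 ^ col_indices * 479001599
    uniform = (hashed % (2**24)).float() / (2**24)  # pseudo-uniform in [0, 1)
    eps = 1e-9
    gumbel = -torch.log(-torch.log(uniform + eps) + eps)
    return torch.argmax(torch.log(inputs + eps) + gumbel, dim=1, keepdim=True)


probs = torch.tensor([[0.1, 0.7, 0.2], [0.4, 0.4, 0.2]])
seed = torch.tensor([42, 42])
positions = torch.tensor([5, 6])

first = seeded_multinomial(probs, seed, positions)
second = seeded_multinomial(probs, seed, positions)
assert torch.equal(first, second)  # same seed and position -> same token every time
```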
def sampling_from_probs_torch(probs: torch.Tensor):
"""A sampling implementation with native pytorch operations, without
top-k, top-p, or min-p filtering."""
......
......@@ -67,7 +67,7 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton
......
......@@ -270,9 +270,7 @@ class TpModelWorker:
logits_output, model_worker_batch
)
else:
next_token_ids = self.model_runner.sample(
logits_output, model_worker_batch
)
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
return logits_output, next_token_ids, can_run_cuda_graph
else:
......
......@@ -2049,7 +2049,6 @@ class ModelRunner:
)
self._preprocess_logits(logits_output, forward_batch.sampling_info)
# Sample the next tokens
next_token_ids = self.sampler(
logits_output,
......@@ -2057,6 +2056,12 @@ class ModelRunner:
forward_batch.return_logprob,
forward_batch.top_logprobs_nums,
forward_batch.token_ids_logprobs,
# For prefill, we only use the position of the last token.
(
forward_batch.positions
if forward_batch.forward_mode.is_decode()
else forward_batch.seq_lens - 1
),
)
return next_token_ids
......
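The positions argument passed to the sampler above differs by forward mode: a decode batch already holds one position per request (the token being generated), while a prefill batch samples only one token per request, so the seed position is the last prompt token. A tiny illustration with made-up values:

```python
import torch

# Decode: one entry per request, the index of the token currently being generated.
decode_positions = torch.tensor([17, 4, 251])

# Prefill: only one token is sampled per request, so the seed is derived from
# the last prompt position of each request.
seq_lens = torch.tensor([18, 5, 252])
prefill_positions = seq_lens - 1  # tensor([17, 4, 251])
```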
......@@ -60,6 +60,9 @@ class SamplingBatchInfo:
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
] = None
# Used for deterministic sampling
sampling_seed: Optional[torch.Tensor] = None
# Device
device: str = "cuda"
......@@ -93,6 +96,15 @@ class SamplingBatchInfo:
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
sampling_seed = (
torch.tensor(
[r.sampling_params.sampling_seed for r in reqs],
dtype=torch.int32,
device=device,
)
if enable_deterministic
else None
)
logit_bias = None
if any(r.sampling_params.logit_bias is not None for r in reqs):
......@@ -158,6 +170,7 @@ class SamplingBatchInfo:
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
sampling_seed=sampling_seed,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
......@@ -239,9 +252,11 @@ class SamplingBatchInfo:
"top_ps",
"top_ks",
"min_ps",
"sampling_seed",
]:
value = getattr(self, item, None)
setattr(self, item, value[keep_indices_device])
if value is not None:
setattr(self, item, value[keep_indices_device])
if self.logit_bias is not None:
self.logit_bias = self.logit_bias[keep_indices_device]
......@@ -343,10 +358,12 @@ class SamplingBatchInfo:
"top_ps",
"top_ks",
"min_ps",
"sampling_seed",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
setattr(self, item, torch.cat([self_val, other_val]))
if self_val is not None and other_val is not None:
setattr(self, item, torch.cat([self_val, other_val]))
self.is_all_greedy &= other.is_all_greedy
self.need_top_p_sampling |= other.need_top_p_sampling
......
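filter_batch and merge_batch now treat sampling_seed like the other per-request tensors, with a None guard because the tensor only exists in deterministic mode. A toy sketch of what the two operations do to such a tensor (illustrative values only):

```python
import torch

sampling_seed = torch.tensor([42, 7, 1234], dtype=torch.int32)

# filter_batch: keep only the requests that are still running.
keep_indices = torch.tensor([0, 2])
sampling_seed = sampling_seed[keep_indices]  # tensor([42, 1234], dtype=torch.int32)

# merge_batch: append the seeds of another batch being merged in.
other_seed = torch.tensor([99], dtype=torch.int32)
sampling_seed = torch.cat([sampling_seed, other_seed])  # tensor([42, 1234, 99], dtype=torch.int32)
```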
......@@ -15,8 +15,11 @@
from typing import Any, Dict, List, Optional, Union
from sglang.srt.utils import get_bool_env_var
_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30
DEFAULT_SAMPLING_SEED = 42
class SamplingParams:
......@@ -53,6 +56,7 @@ class SamplingParams:
custom_params: Optional[Dict[str, Any]] = None,
stream_interval: Optional[int] = None,
logit_bias: Optional[Dict[str, float]] = None,
sampling_seed: Optional[int] = None,
) -> None:
self.max_new_tokens = max_new_tokens
self.stop_strs = stop
......@@ -80,6 +84,14 @@ class SamplingParams:
self.custom_params = custom_params
self.stream_interval = stream_interval
self.logit_bias = logit_bias
# Used for deterministic sampling
if (
get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE")
and sampling_seed is None
):
# If deterministic inference is enabled and sampling_seed is not set, use the default seed
sampling_seed = DEFAULT_SAMPLING_SEED
self.sampling_seed = sampling_seed
# Process some special cases
if 0 <= self.temperature < _SAMPLING_EPS:
......
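With the server running in deterministic mode (the SGLANG_ENABLE_DETERMINISTIC_INFERENCE environment variable checked above), a per-request seed can be passed through sampling_params, as the benchmark script further down does. A sketch of such a request; the endpoint, host, and prompt field are assumptions based on SGLang's native generate API, and the sampling_params keys follow the benchmark payloads:

```python
import requests

# Assumes an SGLang server launched with SGLANG_ENABLE_DETERMINISTIC_INFERENCE=1
# and listening on localhost:30000.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0.7,
            "max_new_tokens": 32,
            "sampling_seed": 123,  # per-request seed; defaults to 42 when omitted in deterministic mode
        },
    },
)
print(response.json())
```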
......@@ -988,6 +988,12 @@ class ServerArgs:
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
)
# Force the PyTorch sampling backend for deterministic inference
self.sampling_backend = "pytorch"
logger.warning(
"Sampling backend is set to pytorch for deterministic inference."
)
# Currently, only FA3 supports radix cache. Support for other backends is in progress.
if self.attention_backend != "fa3":
self.disable_radix_cache = True
logger.warning(
......
......@@ -29,6 +29,7 @@ class BenchArgs:
port: int = 30000
batch_size: int = 1
temperature: float = 0.0
sampling_seed: int = None
max_new_tokens: int = 100
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
......@@ -45,6 +46,9 @@ class BenchArgs:
parser.add_argument("--port", type=int, default=BenchArgs.port)
parser.add_argument("--n-trials", type=int, default=50)
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
parser.add_argument(
"--sampling-seed", type=int, default=BenchArgs.sampling_seed
)
parser.add_argument(
"--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
)
......@@ -92,6 +96,7 @@ def send_single(
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
......@@ -140,6 +145,7 @@ def send_mixed(args, batch_size: int):
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
......@@ -186,6 +192,7 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"sampling_seed": args.sampling_seed,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
......
......@@ -97,6 +97,7 @@ fn default_completion_request() -> CompletionRequest {
lora_path: None,
session_params: None,
return_hidden_states: false,
sampling_seed: None,
other: serde_json::Map::new(),
}
}
......
......@@ -367,6 +367,10 @@ pub struct ChatCompletionRequest {
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
/// Random seed for sampling, used to make outputs deterministic
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_seed: Option<u64>,
}
impl GenerationRequest for ChatCompletionRequest {
......@@ -608,6 +612,10 @@ pub struct CompletionRequest {
#[serde(default)]
pub return_hidden_states: bool,
/// Sampling seed for deterministic outputs
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_seed: Option<u64>,
/// Additional fields including bootstrap info for PD routing
#[serde(flatten)]
pub other: serde_json::Map<String, serde_json::Value>,
......@@ -1749,6 +1757,8 @@ pub struct SamplingParams {
pub stop_token_ids: Option<Vec<i32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub no_stop_trim: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_seed: Option<u64>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
......
......@@ -240,6 +240,7 @@ impl super::super::RouterTrait for OpenAIRouter {
"chat_template_kwargs",
"return_hidden_states",
"repetition_penalty",
"sampling_seed",
] {
obj.remove(key);
}
......
......@@ -68,6 +68,7 @@ fn create_minimal_completion_request() -> CompletionRequest {
lora_path: None,
session_params: None,
return_hidden_states: false,
sampling_seed: None,
other: serde_json::Map::new(),
}
}
......