Unverified Commit 068e9eae authored by intervitens, committed by GitHub

Support min-p sampling (#1167)

parent d6aeb9fa
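
Min-p sampling keeps only the tokens whose probability is at least min_p times the probability of the single most likely token, so the candidate set shrinks when the model is confident and widens when it is uncertain. A minimal sketch of the rule, illustrative only and not code from this commit:

import torch

def min_p_filter(probs: torch.Tensor, min_p: float) -> torch.Tensor:
    # probs: (batch, vocab), rows summing to 1. Keep token i iff
    # probs[i] >= min_p * max(probs), then renormalize the survivors.
    threshold = probs.max(dim=-1, keepdim=True).values * min_p
    kept = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    return kept / kept.sum(dim=-1, keepdim=True)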
@@ -45,6 +45,8 @@ temperature: float = 1.0,
top_p: float = 1.0,
# Top-k sampling
top_k: int = -1,
# Min-p sampling
min_p: float = 0.0,
# Whether to ignore EOS token.
ignore_eos: bool = False,
# Whether to skip the special tokens during detokenization.
......
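
The hunk above documents the new field next to the existing sampling parameters; min_p defaults to 0.0, which disables the filter. A hedged usage sketch of passing it through the /generate sampling params (server address and prompt are illustrative):

import requests

resp = requests.post(
    "http://localhost:30000/generate",  # assumed local sglang server
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0.8, "min_p": 0.05, "max_new_tokens": 16},
    },
)
print(resp.json())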
@@ -66,6 +66,7 @@ def gen(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -103,6 +104,7 @@ def gen(
temperature,
top_p,
top_k,
min_p,
frequency_penalty,
presence_penalty,
ignore_eos,
@@ -123,6 +125,7 @@ def gen_int(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -139,6 +142,7 @@ def gen_int(
temperature,
top_p,
top_k,
min_p,
frequency_penalty,
presence_penalty,
ignore_eos,
@@ -159,6 +163,7 @@ def gen_string(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -175,6 +180,7 @@ def gen_string(
temperature,
top_p,
top_k,
min_p,
frequency_penalty,
presence_penalty,
ignore_eos,
......
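
With the gen, gen_int, and gen_string hunks above applied, min_p threads through the frontend exactly like temperature or top_p. A hedged sketch (the program body is illustrative, and a backend must be attached before running):

import sglang as sgl

@sgl.function
def story(s):
    s += "Once upon a time,"
    s += sgl.gen("tale", max_tokens=64, temperature=0.9, min_p=0.05)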
@@ -130,6 +130,7 @@ class CompiledFunction:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
@@ -145,6 +146,7 @@ class CompiledFunction:
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
@@ -160,6 +162,7 @@ class CompiledFunction:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
backend=None,
@@ -178,6 +181,7 @@ class CompiledFunction:
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
)
......
@@ -663,6 +663,7 @@ class StreamExecutor:
"temperature",
"top_p",
"top_k",
"min_p",
"frequency_penalty",
"presence_penalty",
"ignore_eos",
......
@@ -22,6 +22,7 @@ class SglSamplingParams:
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
min_p: float = 0.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
ignore_eos: bool = False
@@ -42,6 +43,7 @@ class SglSamplingParams:
self.temperature,
self.top_p,
self.top_k,
self.min_p,
self.frequency_penalty,
self.presence_penalty,
self.ignore_eos,
@@ -114,6 +116,7 @@ class SglSamplingParams:
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": self.min_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
@@ -149,6 +152,7 @@ class SglFunction:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
@@ -169,6 +173,7 @@ class SglFunction:
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
@@ -190,6 +195,7 @@ class SglFunction:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
@@ -228,6 +234,7 @@ class SglFunction:
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
@@ -408,6 +415,7 @@ class SglGen(SglExpr):
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
min_p: Optional[float] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
@@ -428,6 +436,7 @@ class SglGen(SglExpr):
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
......
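
Per the SglFunction.run signature change above, min_p can also be supplied at call time as a default for every gen in the program. Reusing the story function from the earlier sketch:

state = story.run(min_p=0.05)
print(state["tale"])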
@@ -21,7 +21,12 @@ from typing import List, Optional, Union
import torch
import torch.distributed as dist
-from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from flashinfer.sampling import (
+    min_p_sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
+)
from vllm.distributed import get_tensor_model_parallel_group
import sglang.srt.sampling.penaltylib as penaltylib
@@ -339,6 +344,7 @@ class ScheduleBatch:
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
logit_bias: torch.Tensor = None
@@ -403,6 +409,9 @@ class ScheduleBatch:
self.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
)
self.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
@@ -701,6 +710,7 @@ class ScheduleBatch:
"temperatures",
"top_ps",
"top_ks",
"min_ps",
"logit_bias",
]:
self_val = getattr(self, item, None)
@@ -730,6 +740,7 @@ class ScheduleBatch:
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
@@ -780,13 +791,20 @@ class ScheduleBatch:
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
            )
+           if self.min_ps.any():
+               probs = top_k_renorm_prob(probs, self.top_ks)
+               probs = top_p_renorm_prob(probs, self.top_ps)
+               batch_next_token_ids, success = min_p_sampling_from_probs(
+                   probs, uniform_samples, self.min_ps
+               )
+           else:
                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                    probs, uniform_samples, self.top_ks, self.top_ps
                )
        else:
            # Here we provide a slower fallback implementation.
-           batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-               probs, self.top_ks, self.top_ps
+           batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+               probs, self.top_ks, self.top_ps, self.min_ps
            )

        if not torch.all(success):
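
In the flashinfer branch, top-k and top-p are first applied as renormalizations (top_k_renorm_prob, top_p_renorm_prob), presumably because the fused min_p_sampling_from_probs kernel only applies the min-p filter itself. A worked toy example of the shared thresholding rule, with illustrative values:

import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
min_p = 0.2                                    # threshold = 0.2 * 0.50 = 0.10
threshold = probs.max(dim=-1, keepdim=True).values * min_p
kept = probs * (probs >= threshold)            # [[0.50, 0.30, 0.15, 0.00]]
print(kept / kept.sum(dim=-1, keepdim=True))   # ~[[0.5263, 0.3158, 0.1579, 0.0000]]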
@@ -810,17 +828,22 @@
return batch_next_token_ids
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+def top_k_top_p_min_p_sampling_from_probs_torch(
+    probs: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+    min_ps: torch.Tensor,
 ):
-    """A top-k and top-k sampling implementation with native pytorch operations."""
+    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
+    min_p_thresholds = probs_sort[:, 0] * min_ps
     probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
+    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
......
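
A hedged usage sketch of the PyTorch fallback above; shapes and values are assumed, and the function's truncated tail is taken to return (token_ids, success) as at its call site:

import torch

probs = torch.softmax(torch.randn(4, 32000), dim=-1)  # (batch, vocab)
top_ks = torch.full((4,), 40, dtype=torch.int)
top_ps = torch.full((4,), 0.90)
min_ps = torch.full((4,), 0.05)
token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
    probs, top_ks, top_ps, min_ps
)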
@@ -30,6 +30,7 @@ class SamplingParams:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
@@ -42,6 +43,7 @@ class SamplingParams:
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.repetition_penalty = repetition_penalty
@@ -69,6 +71,8 @@ class SamplingParams:
)
if not 0.0 < self.top_p <= 1.0:
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
if not 0.0 <= self.min_p <= 1.0:
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(
f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
......
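
A hedged sketch of the new check; the call below assumes validation runs through the same verify-style method as the neighboring top_p and top_k checks:

params = SamplingParams(temperature=0.7, min_p=0.05)  # valid: min_p in [0, 1]
bad = SamplingParams(min_p=1.5)
bad.verify()  # method name assumed; raises ValueError("min_p must be in [0, 1], got 1.5.")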