"git@developer.sourcefind.cn:modelzoo/diffbir_pytorch.git" did not exist on "c18c4eb658fbe32dc475f1f75667ae0186a8ef4f"
Unverified commit 068e9eae, authored by intervitens, committed by GitHub

Support min-p sampling (#1167)

parent d6aeb9fa
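
For context: min-p sampling keeps only tokens whose probability is at least `min_p` times the probability of the most likely token, so the candidate set shrinks when the model is confident and widens when the distribution is flat. A minimal, self-contained sketch of the idea in plain PyTorch (toy numbers; independent of the code changed below):

```python
import torch

# Toy next-token distribution for a single request (4-token vocabulary).
probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
min_p = 0.4  # keep tokens with prob >= 0.4 * 0.50 = 0.20

# The per-row threshold scales with the top token's probability.
threshold = min_p * probs.max(dim=-1, keepdim=True).values
masked = torch.where(probs >= threshold, probs, torch.zeros_like(probs))

# Renormalize the survivors and sample; only two tokens remain here.
masked = masked / masked.sum(dim=-1, keepdim=True)
next_token = torch.multinomial(masked, num_samples=1)
print(masked)  # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])
```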
@@ -45,6 +45,8 @@ temperature: float = 1.0,
     top_p: float = 1.0,
     # Top-k sampling
     top_k: int = -1,
+    # Min-p sampling
+    min_p: float = 0.0,
     # Whether to ignore EOS token.
     ignore_eos: bool = False,
     # Whether to skip the special tokens during detokenization.
...
@@ -66,6 +66,7 @@ def gen(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -103,6 +104,7 @@ def gen(
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
@@ -123,6 +125,7 @@ def gen_int(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -139,6 +142,7 @@ def gen_int(
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
@@ -159,6 +163,7 @@ def gen_string(
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
+    min_p: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
@@ -175,6 +180,7 @@ def gen_string(
         temperature,
         top_p,
         top_k,
+        min_p,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
...
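
With the API changes above, `min_p` can be passed straight through `gen` (and `gen_int`/`gen_string`). A hedged usage sketch — `min_p` comes from this diff, while the decorator style and the other `gen` arguments are assumed from SGLang's frontend conventions:

```python
import sglang as sgl

@sgl.function
def tell_joke(s):
    s += "Tell me a joke: "
    # min_p flows from gen() into the sampling params, per this commit.
    s += sgl.gen("joke", max_tokens=64, temperature=0.8, min_p=0.05)
```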
@@ -130,6 +130,7 @@ class CompiledFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         backend=None,
@@ -145,6 +146,7 @@ class CompiledFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
         )
@@ -160,6 +162,7 @@ class CompiledFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         backend=None,
@@ -178,6 +181,7 @@ class CompiledFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
         )
...
@@ -663,6 +663,7 @@ class StreamExecutor:
             "temperature",
             "top_p",
             "top_k",
+            "min_p",
             "frequency_penalty",
             "presence_penalty",
             "ignore_eos",
...
@@ -22,6 +22,7 @@ class SglSamplingParams:
     temperature: float = 1.0
     top_p: float = 1.0
     top_k: int = -1  # -1 means disable
+    min_p: float = 0.0
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
     ignore_eos: bool = False
@@ -42,6 +43,7 @@ class SglSamplingParams:
             self.temperature,
             self.top_p,
             self.top_k,
+            self.min_p,
             self.frequency_penalty,
             self.presence_penalty,
             self.ignore_eos,
@@ -114,6 +116,7 @@ class SglSamplingParams:
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,
+            "min_p": self.min_p,
             "frequency_penalty": self.frequency_penalty,
             "presence_penalty": self.presence_penalty,
             "ignore_eos": self.ignore_eos,
@@ -149,6 +152,7 @@ class SglFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
@@ -169,6 +173,7 @@ class SglFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
@@ -190,6 +195,7 @@ class SglFunction:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
@@ -228,6 +234,7 @@ class SglFunction:
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
@@ -408,6 +415,7 @@ class SglGen(SglExpr):
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
+        min_p: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         presence_penalty: Optional[float] = None,
         ignore_eos: Optional[bool] = None,
@@ -428,6 +436,7 @@ class SglGen(SglExpr):
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
...
@@ -21,7 +21,12 @@ from typing import List, Optional, Union
 import torch
 import torch.distributed as dist
-from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from flashinfer.sampling import (
+    min_p_sampling_from_probs,
+    top_k_renorm_prob,
+    top_k_top_p_sampling_from_probs,
+    top_p_renorm_prob,
+)
 from vllm.distributed import get_tensor_model_parallel_group
 
 import sglang.srt.sampling.penaltylib as penaltylib
@@ -339,6 +344,7 @@ class ScheduleBatch:
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
+    min_ps: torch.Tensor = None
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     logit_bias: torch.Tensor = None
@@ -403,6 +409,9 @@
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
         )
+        self.min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
@@ -701,6 +710,7 @@ class ScheduleBatch:
             "temperatures",
             "top_ps",
             "top_ks",
+            "min_ps",
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
@@ -730,6 +740,7 @@ class ScheduleBatch:
             "temperatures",
             "top_ps",
             "top_ks",
+            "min_ps",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
@@ -780,13 +791,20 @@ class ScheduleBatch:
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                probs, uniform_samples, self.top_ks, self.top_ps
-            )
+            if self.min_ps.any():
+                probs = top_k_renorm_prob(probs, self.top_ks)
+                probs = top_p_renorm_prob(probs, self.top_ps)
+                batch_next_token_ids, success = min_p_sampling_from_probs(
+                    probs, uniform_samples, self.min_ps
+                )
+            else:
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs, uniform_samples, self.top_ks, self.top_ps
+                )
         else:
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps
-            )
+            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps, self.min_ps
+            )
 
         if not torch.all(success):
@@ -810,17 +828,22 @@
         return batch_next_token_ids
 
 
-def top_k_top_p_sampling_from_probs_torch(
-    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+def top_k_top_p_min_p_sampling_from_probs_torch(
+    probs: torch.Tensor,
+    top_ks: torch.Tensor,
+    top_ps: torch.Tensor,
+    min_ps: torch.Tensor,
 ):
-    """A top-k and top-k sampling implementation with native pytorch operations."""
+    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
+    min_p_thresholds = probs_sort[:, 0] * min_ps
     probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
+    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
...
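
Two details of the scheduler change are worth spelling out. On the flashinfer path, the fused top-k/top-p kernel takes no min-p argument, so when any request in the batch sets `min_p`, the probabilities are first renormalized under top-k and top-p (`top_k_renorm_prob`, `top_p_renorm_prob`) and min-p sampling runs last; batches with no min-p keep the fused fast path. The PyTorch fallback can be exercised directly on CPU. A minimal smoke test, assuming `top_k_top_p_min_p_sampling_from_probs_torch` from the hunk above is importable (shapes and values are illustrative):

```python
import torch

batch_size, vocab_size = 4, 128
probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
top_ks = torch.full((batch_size,), 50, dtype=torch.int)
top_ps = torch.full((batch_size,), 0.9)
min_ps = torch.full((batch_size,), 0.1)

# Returns sampled token ids plus a success flag, matching the call site above.
ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
    probs, top_ks, top_ps, min_ps
)
assert bool(torch.all(success))
```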
@@ -30,6 +30,7 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
@@ -42,6 +43,7 @@ class SamplingParams:
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
@@ -69,6 +71,8 @@ class SamplingParams:
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
                 f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
...
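
Finally, the new range check is observable from user code. A brief sketch; the import path and the assumption that the checks run in a `verify()` method (rather than in `__init__`) follow SGLang's conventions and are not shown in this diff:

```python
from sglang.srt.sampling_params import SamplingParams  # import path assumed

# min_p defaults to 0.0 (disabled); any value in [0, 1] passes validation.
params = SamplingParams(temperature=0.7, top_p=0.9, top_k=50, min_p=0.05)

bad = SamplingParams(min_p=1.5)
bad.verify()  # assumed entry point; raises ValueError: min_p must be in [0, 1], got 1.5.
```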