Unverified commit e87557b0, authored by Roy, committed by GitHub

Support Min P Sampler (#1642)

parent dcc543a2
@@ -71,13 +71,18 @@ class Sampler(nn.Module):
         logits.div_(t.unsqueeze(dim=1))

         # Apply top-p and top-k truncation.
-        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
+            input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
             logits = _apply_top_p_top_k(logits, top_ps, top_ks)

+        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
+        if do_min_p:
+            logits = _apply_min_p(logits, min_ps)
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -261,15 +266,17 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     return temperatures


-def _get_top_p_top_k(
+def _get_top_p_top_k_min_p(
     input_metadata: InputMetadata,
     vocab_size: int,
-) -> Tuple[List[float], List[int]]:
+) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
+    min_ps: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
+        min_p = sampling_params.min_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
@@ -279,9 +286,11 @@ def _get_top_p_top_k(
             prompt_len = input_metadata.prompt_lens[i]
             top_ps += [top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
+            min_ps += [min_p] * (prompt_len - 1)
         top_ps += [top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
-    return top_ps, top_ks
+        min_ps += [min_p] * len(seq_ids)
+    return top_ps, top_ks, min_ps


 def _apply_top_p_top_k(
@@ -313,6 +322,24 @@ def _apply_top_p_top_k(
     return logits


+def _apply_min_p(
+    logits: torch.Tensor,
+    min_ps: List[float],
+) -> torch.Tensor:
+    """
+    Adapted from
+    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
+    """
+    min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
+    probs = torch.softmax(logits, dim=-1)
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
+    tokens_to_remove = probs < scaled_min_p
+    logits = logits.masked_fill(tokens_to_remove, -float("inf"))
+
+    return logits
+
+
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
...
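To see what the new filter does, here is a minimal standalone sketch (not part of the commit) that applies the same rule as _apply_min_p to a toy distribution; the tensor values are illustrative only:

import torch

# One sequence over a vocabulary of 4 tokens.
logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
min_ps = [0.2]  # keep tokens with at least 20% of the top token's probability

probs = torch.softmax(logits, dim=-1)           # ~[0.64, 0.24, 0.09, 0.03]
top_probs, _ = probs.max(dim=-1, keepdim=True)  # ~0.64
# Threshold scales with the most likely token: ~0.2 * 0.64 = ~0.13.
threshold = torch.tensor(min_ps).unsqueeze(dim=1) * top_probs
filtered = logits.masked_fill(probs < threshold, -float("inf"))
print(filtered)  # the last two logits become -inf; only the top two tokens survive

Because the cutoff is relative to the top token's probability rather than a fixed cumulative mass, min-p adapts to how peaked the distribution is: a confident distribution prunes aggressively, a flat one keeps more candidates.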
@@ -52,6 +52,9 @@ class SamplingParams:
             to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
+        min_p: Float that represents the minimum probability for a token to be
+            considered, relative to the probability of the most likely token.
+            Must be in [0, 1]. Set to 0 to disable this.
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.
@@ -94,6 +97,7 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -115,6 +119,7 @@ class SamplingParams:
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -167,6 +172,9 @@ class SamplingParams:
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                              f"got {self.top_k}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError("min_p must be in [0, 1], got "
+                             f"{self.min_p}.")
         if self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -228,6 +236,7 @@ class SamplingParams:
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
+                f"min_p={self.min_p}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "
...
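With the parameter in place, a request can enable min-p alongside the existing sampling knobs. A minimal usage sketch against vLLM's offline API, assuming a build that includes this commit (the model name and prompt are placeholders):

from vllm import LLM, SamplingParams

# min_p=0.05 discards any token whose probability is below 5% of the
# most likely token's probability; 0.0 (the default) disables the filter.
params = SamplingParams(temperature=0.8, top_p=1.0, min_p=0.05)

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)

Values outside [0, 1], e.g. SamplingParams(min_p=1.5), now raise the ValueError added in the validation hunk above.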