Unverified Commit 09b7c26b authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(server): add frequency penalty (#1541)

parent 39af000c
......@@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
return None
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
r"""
Frequency penalty as defined by OpenAI
Args:
penalty (`float`):
The parameter for frequency penalty. 0.0 means no penalty.
"""
def __init__(self, penalty: float):
self.penalty = penalty
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(
score < 0, score * self.penalty, score / self.penalty
)
return scores.scatter_add_(1, input_ids, score)
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
r"""
Frequency penalty as defined by OpenAI
Args:
frequency_penalty (`List[float]`):
The parameter for frequency penalty. 0.0 means no penalty.
"""
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
self.penalty = penalty
self.penalty_tensor = torch.tensor(
penalty, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids)
# if score < 0 then penalty has to be multiplied to reduce the previous token probability
score = -torch.where(
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
return scores.scatter_add_(1, input_ids, score)
def filter(self, indices):
self.penalty = [self.penalty[i] for i in indices]
if any([x != 0.0 for x in self.penalty]):
self.penalty_tensor = self.penalty_tensor[indices]
return self
return None
class HeterogeneousTemperatureLogitsWarper:
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution).
......
import re
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper,
......@@ -23,6 +25,7 @@ class NextTokenChooser:
watermark=False,
temperature=1.0,
repetition_penalty=1.0,
frequency_penalty=0.0,
top_k=None,
top_p=None,
typical_p=None,
......@@ -35,7 +38,12 @@ class NextTokenChooser:
)
self.repetition_processor = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty
if repetition_penalty and repetition_penalty != 1.0
else None
)
self.frequency_processor = (
FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
if frequency_penalty and frequency_penalty != 0.0
else None
)
......@@ -60,6 +68,8 @@ class NextTokenChooser:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores)
if self.frequency_processor is not None:
scores = self.frequency_processor(input_ids, scores)
if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1)
......@@ -80,6 +90,7 @@ class NextTokenChooser:
watermark=pb.watermark,
temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty,
frequency_penalty=pb.frequency_penalty,
top_k=pb.top_k,
top_p=pb.top_p,
typical_p=pb.typical_p,
......@@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser:
watermark: List[bool],
temperature: List[float],
repetition_penalty: List[float],
frequency_penalty: List[float],
top_k: List[int],
top_p: List[float],
typical_p: List[float],
......@@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser:
else None
)
self.frequency_processor = (
HeterogeneousFrequencyPenaltyLogitsProcessor(
frequency_penalty, dtype, device
)
if any([x != 0.0 for x in frequency_penalty])
else None
)
if any([x != 1.0 for x in temperature]):
do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample)
......@@ -269,6 +289,8 @@ class HeterogeneousNextTokenChooser:
_scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores)
for warper in self.warpers:
_scores = warper(input_ids, _scores)
......@@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser:
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0:
if speculative_scores is not None:
# Medusa provided some scores
......@@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser:
if self.repetition_processor is not None:
self.repetition_processor = self.repetition_processor.filter(indices)
if self.frequency_processor is not None:
self.frequency_processor = self.frequency_processor.filter(indices)
filtered_warpers = []
for warper in self.warpers:
filtered_warper = warper.filter(indices)
......@@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser:
watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb],
......@@ -438,7 +463,10 @@ class HeterogeneousSampling:
def batch_top_tokens(
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor
top_n_tokens: List[int],
top_n_tokens_tensor: torch.Tensor,
logprobs: torch.Tensor,
accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""Find the top n most likely tokens for a batch of generations.
......@@ -450,12 +478,15 @@ def batch_top_tokens(
if max_top_n == 0:
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
batch_size = accepted_ids.shape[0]
speculate_size = logprobs.shape[0] // batch_size
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
# Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)]
top_n_tokens = [
min(tok, logprobs.size(-1))
for tok in top_n_tokens
for _ in range(speculate_size)
]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset
......@@ -484,10 +515,10 @@ def batch_top_tokens(
for i, n_accepted_ids in enumerate(accepted_ids_list):
start = speculate_size * i
stop = speculate_size * (i + 1)
_top_indices = top_indices[start: stop]
_top_values = top_values[start: stop]
_top_n_ishes = top_n_ishes[start: stop]
_top_n_tokens = top_n_tokens[start: stop]
_top_indices = top_indices[start:stop]
_top_values = top_values[start:stop]
_top_n_ishes = top_n_ishes[start:stop]
_top_n_tokens = top_n_tokens[start:stop]
_top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids]
......@@ -497,7 +528,9 @@ def batch_top_tokens(
row_top_token_ids = []
row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens):
for idxs, vals, n, req_n in zip(
_top_indices, _top_values, _top_n_ishes, _top_n_tokens
):
indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment