Unverified Commit 09b7c26b authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(server): add frequency penalty (#1541)

parent 39af000c
...@@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): ...@@ -118,6 +118,62 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
return None return None
class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Frequency penalty as defined by OpenAI: the logit of a token is lowered by
    `penalty` once for every time that token occurs in `input_ids`, i.e.
    `scores[b, t] -= penalty * count_b(t)`.

    Args:
        penalty (`float`):
            The parameter for frequency penalty. 0.0 means no penalty. Positive
            values discourage tokens in proportion to how often they were already
            produced; negative values encourage them.
    """

    def __init__(self, penalty: float):
        self.penalty = penalty

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Apply the penalty to `scores` in place and return it.

        Args:
            input_ids: token ids used as indices into dim 1 of `scores`; every id
                present is counted, once per occurrence.
            scores: logits to penalise; modified in place.

        Returns:
            The same `scores` tensor after the update.
        """
        # The penalty is additive and independent of the sign or magnitude of the
        # logit. The multiplicative repetition-penalty formula must not be reused
        # here: it divides by `penalty` (undefined at the documented neutral value
        # 0.0) and rescales negative logits towards zero, which boosts them.
        per_occurrence = torch.full(
            input_ids.shape, -self.penalty, dtype=scores.dtype, device=scores.device
        )
        # scatter_add_ accumulates over duplicate indices, so a token seen k times
        # receives k * (-penalty).
        return scores.scatter_add_(1, input_ids, per_occurrence)
class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    Frequency penalty as defined by OpenAI, with one penalty value per batch row:
    `scores[b, t] -= penalty[b] * count_b(t)`, where `count_b(t)` is the number
    of times token `t` occurs in row `b` of `input_ids`.

    Args:
        penalty (`List[float]`):
            The parameter for frequency penalty, one entry per batch row.
            0.0 means no penalty for that row.
        dtype (`torch.dtype`):
            dtype of the internal penalty tensor.
        device (`torch.device`):
            device of the internal penalty tensor.
    """

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        # Column vector so it broadcasts across the positions of each row.
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        """
        Apply the per-row penalty to `scores` in place and return it.

        Args:
            input_ids: token ids used as indices into dim 1 of `scores`; every id
                present is counted, once per occurrence.
            scores: logits to penalise; modified in place.

        Returns:
            The same `scores` tensor after the update.
        """
        row_penalty = self.penalty_tensor.to(dtype=scores.dtype, device=scores.device)
        # One `-penalty[b]` contribution per position. Rows whose penalty is 0.0
        # contribute exactly zero, so a mixed batch leaves those rows untouched
        # (the multiplicative formula divided by the penalty and produced -inf
        # for them).
        per_occurrence = torch.ones_like(input_ids, dtype=scores.dtype) * (-row_penalty)
        # scatter_add_ accumulates over duplicate indices, which yields the
        # occurrence count times the penalty for each token.
        return scores.scatter_add_(1, input_ids, per_occurrence)

    def filter(self, indices):
        """
        Keep only the batch rows listed in `indices`.

        Returns:
            `self` when at least one remaining row has a non-zero penalty,
            otherwise `None` so the caller can drop the processor.
        """
        self.penalty = [self.penalty[i] for i in indices]
        if any(x != 0.0 for x in self.penalty):
            self.penalty_tensor = self.penalty_tensor[indices]
            return self
        return None
class HeterogeneousTemperatureLogitsWarper: class HeterogeneousTemperatureLogitsWarper:
r""" r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution). [`LogitsWarper`] for temperature (exponential scaling output probability distribution).
......
import re import re
from typing import Callable, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import FinishReason from text_generation_server.pb.generate_pb2 import FinishReason
from text_generation_server.utils.logits_process import ( from text_generation_server.utils.logits_process import (
FrequencyPenaltyLogitsProcessor,
HeterogeneousProcessorWrapper, HeterogeneousProcessorWrapper,
HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousRepetitionPenaltyLogitsProcessor,
HeterogeneousFrequencyPenaltyLogitsProcessor,
HeterogeneousTemperatureLogitsWarper, HeterogeneousTemperatureLogitsWarper,
HeterogeneousTopKLogitsWarper, HeterogeneousTopKLogitsWarper,
HeterogeneousTopPLogitsWarper, HeterogeneousTopPLogitsWarper,
...@@ -23,6 +25,7 @@ class NextTokenChooser: ...@@ -23,6 +25,7 @@ class NextTokenChooser:
watermark=False, watermark=False,
temperature=1.0, temperature=1.0,
repetition_penalty=1.0, repetition_penalty=1.0,
frequency_penalty=0.0,
top_k=None, top_k=None,
top_p=None, top_p=None,
typical_p=None, typical_p=None,
...@@ -35,7 +38,12 @@ class NextTokenChooser: ...@@ -35,7 +38,12 @@ class NextTokenChooser:
) )
self.repetition_processor = ( self.repetition_processor = (
RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
if repetition_penalty if repetition_penalty and repetition_penalty != 1.0
else None
)
self.frequency_processor = (
FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)
if frequency_penalty and frequency_penalty != 0.0
else None else None
) )
...@@ -60,6 +68,8 @@ class NextTokenChooser: ...@@ -60,6 +68,8 @@ class NextTokenChooser:
scores = self.watermark_processor(input_ids, scores) scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None: if self.repetition_processor is not None:
scores = self.repetition_processor(input_ids, scores) scores = self.repetition_processor(input_ids, scores)
if self.frequency_processor is not None:
scores = self.frequency_processor(input_ids, scores)
if self.static_warper is None: if self.static_warper is None:
next_logprob = torch.log_softmax(scores, -1) next_logprob = torch.log_softmax(scores, -1)
...@@ -80,6 +90,7 @@ class NextTokenChooser: ...@@ -80,6 +90,7 @@ class NextTokenChooser:
watermark=pb.watermark, watermark=pb.watermark,
temperature=pb.temperature, temperature=pb.temperature,
repetition_penalty=pb.repetition_penalty, repetition_penalty=pb.repetition_penalty,
frequency_penalty=pb.frequency_penalty,
top_k=pb.top_k, top_k=pb.top_k,
top_p=pb.top_p, top_p=pb.top_p,
typical_p=pb.typical_p, typical_p=pb.typical_p,
...@@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser: ...@@ -184,6 +195,7 @@ class HeterogeneousNextTokenChooser:
watermark: List[bool], watermark: List[bool],
temperature: List[float], temperature: List[float],
repetition_penalty: List[float], repetition_penalty: List[float],
frequency_penalty: List[float],
top_k: List[int], top_k: List[int],
top_p: List[float], top_p: List[float],
typical_p: List[float], typical_p: List[float],
...@@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser: ...@@ -212,6 +224,14 @@ class HeterogeneousNextTokenChooser:
else None else None
) )
self.frequency_processor = (
HeterogeneousFrequencyPenaltyLogitsProcessor(
frequency_penalty, dtype, device
)
if any([x != 0.0 for x in frequency_penalty])
else None
)
if any([x != 1.0 for x in temperature]): if any([x != 1.0 for x in temperature]):
do_sample = [ do_sample = [
sample or x != 1.0 for x, sample in zip(temperature, do_sample) sample or x != 1.0 for x, sample in zip(temperature, do_sample)
...@@ -269,6 +289,8 @@ class HeterogeneousNextTokenChooser: ...@@ -269,6 +289,8 @@ class HeterogeneousNextTokenChooser:
_scores = self.watermark_processor(input_ids, _scores) _scores = self.watermark_processor(input_ids, _scores)
if self.repetition_processor is not None: if self.repetition_processor is not None:
_scores = self.repetition_processor(input_ids, _scores) _scores = self.repetition_processor(input_ids, _scores)
if self.frequency_processor is not None:
_scores = self.frequency_processor(input_ids, _scores)
for warper in self.warpers: for warper in self.warpers:
_scores = warper(input_ids, _scores) _scores = warper(input_ids, _scores)
...@@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser: ...@@ -316,7 +338,6 @@ class HeterogeneousNextTokenChooser:
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
if speculate > 0: if speculate > 0:
if speculative_scores is not None: if speculative_scores is not None:
# Medusa provided some scores # Medusa provided some scores
...@@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser: ...@@ -338,6 +359,9 @@ class HeterogeneousNextTokenChooser:
if self.repetition_processor is not None: if self.repetition_processor is not None:
self.repetition_processor = self.repetition_processor.filter(indices) self.repetition_processor = self.repetition_processor.filter(indices)
if self.frequency_processor is not None:
self.frequency_processor = self.frequency_processor.filter(indices)
filtered_warpers = [] filtered_warpers = []
for warper in self.warpers: for warper in self.warpers:
filtered_warper = warper.filter(indices) filtered_warper = warper.filter(indices)
...@@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser: ...@@ -366,6 +390,7 @@ class HeterogeneousNextTokenChooser:
watermark=[pb_.watermark for pb_ in pb], watermark=[pb_.watermark for pb_ in pb],
temperature=[pb_.temperature for pb_ in pb], temperature=[pb_.temperature for pb_ in pb],
repetition_penalty=[pb_.repetition_penalty for pb_ in pb], repetition_penalty=[pb_.repetition_penalty for pb_ in pb],
frequency_penalty=[pb_.frequency_penalty for pb_ in pb],
top_k=[pb_.top_k for pb_ in pb], top_k=[pb_.top_k for pb_ in pb],
top_p=[pb_.top_p for pb_ in pb], top_p=[pb_.top_p for pb_ in pb],
typical_p=[pb_.typical_p for pb_ in pb], typical_p=[pb_.typical_p for pb_ in pb],
...@@ -438,7 +463,10 @@ class HeterogeneousSampling: ...@@ -438,7 +463,10 @@ class HeterogeneousSampling:
def batch_top_tokens( def batch_top_tokens(
top_n_tokens: List[int], top_n_tokens_tensor: torch.Tensor, logprobs: torch.Tensor, accepted_ids: torch.Tensor top_n_tokens: List[int],
top_n_tokens_tensor: torch.Tensor,
logprobs: torch.Tensor,
accepted_ids: torch.Tensor,
) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:
"""Find the top n most likely tokens for a batch of generations. """Find the top n most likely tokens for a batch of generations.
...@@ -450,12 +478,15 @@ def batch_top_tokens( ...@@ -450,12 +478,15 @@ def batch_top_tokens(
if max_top_n == 0: if max_top_n == 0:
return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens) return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)
batch_size = accepted_ids.shape[0] batch_size = accepted_ids.shape[0]
speculate_size = logprobs.shape[0] // batch_size speculate_size = logprobs.shape[0] // batch_size
top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size) top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)
# Ensure top_n doesn't exceed vocab size # Ensure top_n doesn't exceed vocab size
top_n_tokens = [min(tok, logprobs.size(-1)) for tok in top_n_tokens for _ in range(speculate_size)] top_n_tokens = [
min(tok, logprobs.size(-1))
for tok in top_n_tokens
for _ in range(speculate_size)
]
# Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2 # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2
# Sorted topk is faster than torch.sort() since we only need a small subset # Sorted topk is faster than torch.sort() since we only need a small subset
...@@ -484,10 +515,10 @@ def batch_top_tokens( ...@@ -484,10 +515,10 @@ def batch_top_tokens(
for i, n_accepted_ids in enumerate(accepted_ids_list): for i, n_accepted_ids in enumerate(accepted_ids_list):
start = speculate_size * i start = speculate_size * i
stop = speculate_size * (i + 1) stop = speculate_size * (i + 1)
_top_indices = top_indices[start: stop] _top_indices = top_indices[start:stop]
_top_values = top_values[start: stop] _top_values = top_values[start:stop]
_top_n_ishes = top_n_ishes[start: stop] _top_n_ishes = top_n_ishes[start:stop]
_top_n_tokens = top_n_tokens[start: stop] _top_n_tokens = top_n_tokens[start:stop]
_top_indices = _top_indices[:n_accepted_ids] _top_indices = _top_indices[:n_accepted_ids]
_top_values = _top_values[:n_accepted_ids] _top_values = _top_values[:n_accepted_ids]
...@@ -497,7 +528,9 @@ def batch_top_tokens( ...@@ -497,7 +528,9 @@ def batch_top_tokens(
row_top_token_ids = [] row_top_token_ids = []
row_top_token_logprobs = [] row_top_token_logprobs = []
for idxs, vals, n, req_n in zip(_top_indices, _top_values, _top_n_ishes, _top_n_tokens): for idxs, vals, n, req_n in zip(
_top_indices, _top_values, _top_n_ishes, _top_n_tokens
):
indices = idxs[:n] if req_n > 0 else [] indices = idxs[:n] if req_n > 0 else []
values = vals[:n] if req_n > 0 else [] values = vals[:n] if req_n > 0 else []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment