Unverified Commit 69be658b authored by ljss, committed by GitHub

Support repetition_penalty (#1424)

parent beac8dd4
@@ -50,12 +50,13 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = (
+            _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
+        assert len(repetition_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties)
+                                  frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -134,14 +135,17 @@ def _prune_hidden_states(
 def _get_penalties(
-        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
+    input_metadata: InputMetadata
+) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        r = sampling_params.repetition_penalty
         if (i < input_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the
@@ -149,9 +153,11 @@ def _get_penalties(
             prompt_len = input_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
+            repetition_penalties += [1] * (prompt_len - 1)
         presence_penalties += [p] * len(seq_ids)
         frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+        repetition_penalties += [r] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
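For reference, a minimal standalone sketch (made-up lengths and penalty values, no vLLM imports) of the list-building pattern in the hunk above: every logits row gets one entry per penalty list, prompt positions that only need logprobs get neutral entries, and the neutral value for the repetition penalty is 1 rather than 0.

```python
# Hypothetical standalone sketch of the per-row penalty lists built above.
prompt_len = 4           # assumed prompt length (prompt_logprobs requested)
num_seqs_in_group = 2    # assumed number of sequences in the group
p, f, r = 0.5, 0.2, 1.3  # assumed presence/frequency/repetition penalties

# Prompt positions (all but the last) get neutral penalties; every sequence
# in the group then contributes its own configured values.
presence_penalties = [0] * (prompt_len - 1) + [p] * num_seqs_in_group
frequency_penalties = [0] * (prompt_len - 1) + [f] * num_seqs_in_group
repetition_penalties = [1] * (prompt_len - 1) + [r] * num_seqs_in_group

# One entry per logits row, matching the asserts in Sampler.forward above.
assert len(presence_penalties) == len(frequency_penalties) == len(repetition_penalties) == 5
```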
@@ -175,6 +181,7 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
@@ -182,7 +189,9 @@ def _apply_penalties(
             continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
+        r = repetition_penalties[i]
+        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
+                r - 1.0) < _SAMPLING_EPS:
             continue
         break
     else:
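A small sketch of the no-op check above (the epsilon value here is an assumption): because the neutral repetition penalty is 1.0, the early-exit test compares `r - 1.0` against the epsilon rather than `r` itself.

```python
_SAMPLING_EPS = 1e-5  # assumed value of the module-level epsilon

def penalties_are_noop(p: float, f: float, r: float) -> bool:
    # Mirrors the condition above: all three penalties are at their neutral values.
    return (abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS
            and abs(r - 1.0) < _SAMPLING_EPS)

assert penalties_are_noop(0.0, 0.0, 1.0)      # defaults: the penalty pass is skipped
assert not penalties_are_noop(0.0, 0.0, 1.2)  # repetition penalty requested
```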
@@ -206,7 +215,11 @@ def _apply_penalties(
     bin_counts.scatter_add_(1, output_tokens_tensor,
                             torch.ones_like(output_tokens_tensor))
     bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
+    mask = bin_counts > 0

+    repetition_penalties = torch.tensor(repetition_penalties,
+                                        dtype=logits.dtype,
+                                        device=logits.device)
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)
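The following toy sketch (assumed shapes and token ids) shows how `scatter_add_` builds the per-sequence token counts and the boolean `mask` used by the repetition penalty; the padding-bin trick mirrors the hunk above.

```python
import torch

vocab_size = 6
# Two sequences; the second has generated fewer tokens, so its id list is
# padded with vocab_size, an extra bin that is sliced away afterwards.
output_tokens_tensor = torch.tensor([[2, 2, 5], [1, vocab_size, vocab_size]])

bin_counts = torch.zeros((2, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, output_tokens_tensor,
                        torch.ones_like(output_tokens_tensor))
bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
mask = bin_counts > 0                    # True where a token has already appeared.

print(bin_counts)  # row 0: token 2 counted twice, token 5 once
print(mask)
```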
@@ -214,10 +227,15 @@ def _apply_penalties(
                                       dtype=logits.dtype,
                                       device=logits.device)

+    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
+    repetition_penalties[~mask] = 1.0
+    logits = torch.where(logits > 0, logits / repetition_penalties,
+                         logits * repetition_penalties)
+
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
+    logits -= presence_penalties.unsqueeze(dim=1) * mask
     return logits
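To illustrate the new logits transform in isolation, here is a toy sketch with assumed numbers: previously generated tokens with positive logits are divided by the penalty and those with negative logits are multiplied by it, so a penalty above 1 always lowers their scores, while unseen tokens are left untouched.

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
mask = torch.tensor([[True, True, False]])  # tokens 0 and 1 already generated
r = 1.5                                     # assumed repetition penalty

repetition_penalties = torch.full_like(logits, r)
repetition_penalties[~mask] = 1.0           # leave unseen tokens unchanged

penalized = torch.where(logits > 0,
                        logits / repetition_penalties,
                        logits * repetition_penalties)
print(penalized)  # tensor([[ 1.3333, -1.5000,  0.5000]])
```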
@@ -34,6 +34,10 @@ class SamplingParams:
             frequency in the generated text so far. Values > 0 encourage the
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
+        repetition_penalty: Float that penalizes new tokens based on whether
+            they appear in the generated text so far. Values > 1 encourage the
+            model to use new tokens, while values < 1 encourage the model to
+            repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
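A hypothetical usage sketch of the new parameter; the `LLM`/`generate` call follows vLLM's standard offline-inference API, and the model name is only an example.

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(
    temperature=0.8,
    repetition_penalty=1.2,  # > 1 discourages repeating already generated tokens
)
llm = LLM(model="facebook/opt-125m")  # example model
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
```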
@@ -75,6 +79,7 @@ class SamplingParams:
         best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -93,6 +98,7 @@ class SamplingParams:
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
+        self.repetition_penalty = repetition_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -136,6 +142,9 @@ class SamplingParams:
         if not -2.0 <= self.frequency_penalty <= 2.0:
             raise ValueError("frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
+        if not 0.0 < self.repetition_penalty <= 2.0:
+            raise ValueError("repetition_penalty must be in (0, 2], got "
+                             f"{self.repetition_penalty}.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")
@@ -201,6 +210,7 @@ class SamplingParams:
                 f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
+                f"repetition_penalty={self.repetition_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "