"examples/movement-pruning/vscode:/vscode.git/clone" did not exist on "cc746a502032434279702a5ffd2a41443c8def48"
Unverified Commit 865da84a authored by Sherman Siu, committed by GitHub

Add Epsilon- and Eta-Sampling (#21121)

* Add epsilon- and eta-sampling.

Add epsilon- and eta-sampling, following the official code from https://github.com/john-hewitt/truncation-sampling and adapting it to be more configurable, as required by Hugging Face transformers.

* Add unit tests for epsilon- and eta-sampling.

* Black: fix code formatting.

* Fix docstring spacing.

* Clean up newlines.

* Fix implementation bugs and their associated tests.

* Remove epsilon- and eta-sampling parameters from PretrainedConfig.

* Clarify and clean up the documentation.

* Remove parameters for PretrainedConfig test.
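For context, a minimal usage sketch of the new options (not part of this commit; the model name, prompt, and cutoff values are illustrative, the suggested ranges come from the paper):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Truncation sampling works by", return_tensors="pt")

# Epsilon-sampling: drop tokens whose conditional probability is below epsilon_cutoff.
out = model.generate(**inputs, do_sample=True, epsilon_cutoff=3e-4, max_new_tokens=20)

# Eta-sampling: adaptive cutoff min(eta_cutoff, sqrt(eta_cutoff) * exp(-entropy)).
out = model.generate(**inputs, do_sample=True, eta_cutoff=2e-3, max_new_tokens=20)

print(tokenizer.decode(out[0], skip_special_tokens=True))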
......@@ -43,6 +43,8 @@ else:
"ConstrainedBeamSearchScorer",
]
_import_structure["logits_process"] = [
"EpsilonLogitsWarper",
"EtaLogitsWarper",
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
"HammingDiversityLogitsProcessor",
......@@ -162,6 +164,8 @@ if TYPE_CHECKING:
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
......
......@@ -109,6 +109,18 @@ class GenerationConfig(PushToHubMixin):
generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
add up to `typical_p` or higher are kept for generation. See [this
paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
epsilon_cutoff (`float`, *optional*, defaults to 0.0):
If set to float strictly between 0 and 1, only tokens with a conditional probability greater than
`epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the
size of the model. See [Truncation Sampling as Language Model
Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
eta_cutoff (`float`, *optional*, defaults to 0.0):
Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between
0 and 1, a token is only considered if its probability is greater than either `eta_cutoff` or `sqrt(eta_cutoff) *
exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token
probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3,
depending on the size of the model. See [Truncation Sampling as Language Model
Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
diversity_penalty (`float`, *optional*, defaults to 0.0):
This value is subtracted from a beam's score if it generates a token that is the same as a token from any beam of
another group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
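As an illustration of the adaptive cutoff described above (a minimal sketch, not part of the diff; the toy logits are made up):

import torch

eta_cutoff = 2e-3
next_token_logits = torch.tensor([[4.0, 1.0, 0.5, -2.0]])  # toy values

probs = next_token_logits.softmax(dim=-1)
entropy = torch.distributions.Categorical(probs=probs).entropy()
# eta = min(eta_cutoff, sqrt(eta_cutoff) * exp(-entropy)); tokens with prob < eta are filtered out.
eta = torch.min(torch.tensor(eta_cutoff), eta_cutoff**0.5 * torch.exp(-entropy))
kept = probs >= eta.unsqueeze(-1)
print(eta, kept)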
......@@ -223,6 +235,8 @@ class GenerationConfig(PushToHubMixin):
self.top_k = kwargs.pop("top_k", 50)
self.top_p = kwargs.pop("top_p", 1.0)
self.typical_p = kwargs.pop("typical_p", 1.0)
self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
......
......@@ -138,7 +138,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
"""
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int):
for arg_name, arg_value in [
("prompt_length_to_skip", prompt_length_to_skip),
("min_new_tokens", min_new_tokens),
......@@ -152,7 +151,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
if new_tokens_length < self.min_new_tokens:
scores[:, self.eos_token_id] = -float("inf")
......@@ -297,7 +295,6 @@ class TypicalLogitsWarper(LogitsWarper):
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
......@@ -322,6 +319,90 @@ class TypicalLogitsWarper(LogitsWarper):
return scores
class EpsilonLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
Args:
epsilon (`float`):
If set to > 0, only tokens with a probability of `epsilon` or higher are kept for generation.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
epsilon = float(epsilon)
if epsilon <= 0 or epsilon >= 1:
raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
min_tokens_to_keep = int(min_tokens_to_keep)
if min_tokens_to_keep < 1:
raise ValueError(
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
)
self.epsilon = epsilon
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Determine which indices to remove
probabilities = scores.softmax(dim=-1)
indices_to_remove = probabilities < self.epsilon
# Keep the words with the 'min_tokens_to_keep'-highest probabilities
top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check
indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class EtaLogitsWarper(LogitsWarper):
r"""
[`LogitsWarper`] that performs eta-sampling, i.e. calculates a dynamic cutoff `eta := min(epsilon, sqrt(epsilon) *
e^-entropy(probabilities))` and restricts to tokens with `prob >= eta`. Takes the largest `min_tokens_to_keep`
tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
Args:
epsilon (`float`):
A float value in the range (0, 1). Hyperparameter used to compute the dynamic cutoff `eta`. In the paper,
suggested values range from 3e-4 to 2e-3, depending on the size of the model.
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimum number of tokens that cannot be filtered."""
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
epsilon = float(epsilon)
if epsilon <= 0 or epsilon >= 1:
raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
min_tokens_to_keep = int(min_tokens_to_keep)
if min_tokens_to_keep < 1:
raise ValueError(
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
)
self.epsilon = torch.tensor(epsilon)
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(probs=probabilities).entropy()
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
indices_to_remove = probabilities < eta
# Keep the words with the 'min_tokens_to_keep'-highest probabilities
top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check
indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
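For reference, a hedged sketch (not part of the diff) of applying the new warpers directly, chained in a LogitsProcessorList outside of `generate`; the logits and vocabulary size below are made up:

import torch
from transformers.generation import EpsilonLogitsWarper, EtaLogitsWarper, LogitsProcessorList

warpers = LogitsProcessorList(
    [
        EpsilonLogitsWarper(epsilon=3e-4),
        EtaLogitsWarper(epsilon=2e-3),
    ]
)

input_ids = torch.tensor([[0, 1, 2]])  # unused by these warpers, but required by the API
scores = torch.randn(1, 50257)         # fake next-token logits for a ~50k vocabulary
warped = warpers(input_ids, scores)    # filtered entries are set to -inf
next_token = torch.multinomial(warped.softmax(dim=-1), num_samples=1)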
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
......@@ -438,7 +519,6 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
"""
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
......
......@@ -39,6 +39,8 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScor
from .configuration_utils import GenerationConfig
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
......@@ -750,23 +752,22 @@ class GenerationMixin:
# all samplers can be found in `generation_utils_samplers.py`
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
if generation_config.top_k is not None and generation_config.top_k != 0:
    warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
    warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
    warpers.append(
        TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
    )
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
    warpers.append(
        EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
    )
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
    warpers.append(
        EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
    )
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
......@@ -1311,7 +1312,6 @@ class GenerationMixin:
)
elif is_contrastive_search_gen_mode:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
......@@ -1716,7 +1716,6 @@ class GenerationMixin:
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past_key_values") is None:
# prepare inputs
model_kwargs["use_cache"] = True
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
......
......@@ -28,6 +28,8 @@ if is_torch_available():
from transformers.generation import (
EncoderNoRepeatNGramLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
......@@ -288,6 +290,80 @@ class LogitsProcessorTest(unittest.TestCase):
# first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
def test_epsilon_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = torch.log(
torch.tensor(
[[0.87, 0.099, 0.001, 0.03], [0.4, 0.299, 0.101, 0.2]], device=torch_device, dtype=torch.float
)
)
epsilon_warp = EpsilonLogitsWarper(0.1)
filtered_dist = torch.exp(epsilon_warp(input_ids, dist))
# dist should be filtered to only keep values with proba >= 0.1
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.87, 0, 0, 0], [0.4, 0.299, 0.101, 0.2]], device=torch_device, dtype=torch.float
)
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check edge cases with negative and extreme logits
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
batch_size, 1
) - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
epsilon_warp = EpsilonLogitsWarper(5e-2, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = epsilon_warp(input_ids, ramp_logits)
# first batch should keep 3 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
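As a quick sanity check of the expectation above (not part of the diff): the softmax of the first ramp row is roughly [0.632, 0.233, 0.086, 0.031, ...], so exactly three tokens clear the 5e-2 cutoff:

import torch

ramp = torch.arange(10, dtype=torch.float) - 5  # first batch row: [-5, ..., 4]
probs = ramp.softmax(dim=-1)
print(probs.topk(4).values)          # ~[0.632, 0.233, 0.086, 0.031]
print((probs >= 5e-2).sum().item())  # 3 tokens survive epsilon = 5e-2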
def test_eta_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = torch.log(
torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float)
)
eta_warp = EtaLogitsWarper(0.0625)
filtered_dist = torch.exp(eta_warp(input_ids, dist))
# dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p))
# min(0.0625, 0.1320) is the cutoff for the first row and min(0.0625, 0.1644) is for the second
# where H is the entropy function and p is the probability vector.
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.0, 0.1, 0.8, 0.1], [0.0, 0.0, 0.9, 0.0]], device=torch_device, dtype=torch.float
)
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check edge cases with negative and extreme logits
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
batch_size, 1
) - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = eta_warp(input_ids, ramp_logits)
# first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
......