Unverified Commit 005b5157 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: eta sampling numerical stability (#21676)

parent bb6a664e
......@@ -419,7 +419,7 @@ class EtaLogitsWarper(LogitsWarper):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(probs=probabilities).entropy()
entropy = torch.distributions.Categorical(logits=scores).entropy()
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
indices_to_remove = probabilities < eta
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment