"docs/vscode:/vscode.git/clone" did not exist on "b67fd797bec56b59e1cd3ad54fa2783f7d7b7cbc"
Unverified Commit 005b5157 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: eta sampling numerical stability (#21676)

parent bb6a664e
...@@ -419,7 +419,7 @@ class EtaLogitsWarper(LogitsWarper): ...@@ -419,7 +419,7 @@ class EtaLogitsWarper(LogitsWarper):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff # Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1) probabilities = scores.softmax(dim=-1)
entropy = torch.distributions.Categorical(probs=probabilities).entropy() entropy = torch.distributions.Categorical(logits=scores).entropy()
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None] eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
indices_to_remove = probabilities < eta indices_to_remove = probabilities < eta
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment