Unverified Commit 4b4b0792 authored by Catalin Voss's avatar Catalin Voss Committed by GitHub
Browse files

Fix top k generation for k != 0

parent 2152bfea
...@@ -16,11 +16,17 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa ...@@ -16,11 +16,17 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def top_k_logits(logits, k): def top_k_logits(logits, k):
"""
Masks everything but the k top entries as -infinity (1e10).
Used to mask logits such that e^-infinity -> 0 won't contribute to the
sum of the denominator.
"""
if k == 0: if k == 0:
return logits return logits
values, _ = torch.topk(logits, k) else:
min_values = values[:, -1] values = torch.topk(logits, k)[0]
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits) batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True): def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
if start_token is None: if start_token is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment