Commit 018c270a authored by mshoeybi's avatar mshoeybi
Browse files

sampling

parent f1555799
...@@ -58,8 +58,8 @@ def modify_logits_for_top_p_filtering(logits, top_p): ...@@ -58,8 +58,8 @@ def modify_logits_for_top_p_filtering(logits, top_p):
def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0, def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
temperature=1.0, vocab_size=None): temperature=1.0, vocab_size=None):
""" Sample and update the logits and generate a token. """ Sample and update the logits and generate a token.
Note: logits has the dimension [b, v] where b is the batch size Note: logits has the dimension [b, s, v] where b is the batch size,
and v is the vocabulary size. s is the sequence length, and v is the vocabulary size.
Note: logits are modifed in place so the sampling modification Note: logits are modifed in place so the sampling modification
are reflected in the original full logits. are reflected in the original full logits.
If vocab_size is provided, we will make sure the sample that is If vocab_size is provided, we will make sure the sample that is
...@@ -68,11 +68,13 @@ def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0, ...@@ -68,11 +68,13 @@ def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
""" """
# Check logits for consistency. # Check logits for consistency.
assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' assert logits.ndim == 3, 'expected the logits to be of [b, s, v] shape.'
assert logits.is_contiguous(), 'input logits should be contiguous.'
assert logits.type() == 'torch.cuda.FloatTensor', \ assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.' 'input logits should be floats.'
# We always index into the last index in s.
logits = logits[:, -1, :]
# Greedy is just simple argmax. # Greedy is just simple argmax.
if greedy: if greedy:
assert top_k == 0, 'cannot set both greedy and top-k samplings.' assert top_k == 0, 'cannot set both greedy and top-k samplings.'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment