Commit 018c270a authored by mshoeybi

sampling

parent f1555799
@@ -58,8 +58,8 @@ def modify_logits_for_top_p_filtering(logits, top_p):
 def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
                              temperature=1.0, vocab_size=None):
     """ Sample and update the logits and generate a token.
-    Note: logits has the dimension [b, v] where b is the batch size
-        and v is the vocabulary size.
+    Note: logits has the dimension [b, s, v] where b is the batch size,
+        s is the sequence length, and v is the vocabulary size.
     Note: logits are modifed in place so the sampling modification
         are reflected in the original full logits.
     If vocab_size is provided, we will make sure the sample that is
@@ -68,11 +68,13 @@ def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
     """
     # Check logits for consistency.
-    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
     assert logits.is_contiguous(), 'input logits should be contiguous.'
+    assert logits.ndim == 3, 'expected the logits to be of [b, s, v] shape.'
     assert logits.type() == 'torch.cuda.FloatTensor', \
         'input logits should be floats.'
+    # We always index into the last index in s.
+    logits = logits[:, -1, :]
     # Greedy is just simple argmax.
     if greedy:
         assert top_k == 0, 'cannot set both greedy and top-k samplings.'
......
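For context, here is a minimal PyTorch sketch of the sampling flow this change enables. It is not the repository's implementation, and the helper name sample_last_token is hypothetical; it only illustrates narrowing [b, s, v] logits to the last sequence position before greedy argmax or temperature/top-k/top-p sampling, which is what the added logits[:, -1, :] line does.

import torch

def sample_last_token(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0):
    # Hypothetical sketch, not the repository's function.
    assert logits.ndim == 3, 'expected [b, s, v] logits'
    # Keep only the last position along the sequence dimension: [b, s, v] -> [b, v].
    logits = logits[:, -1, :]
    # Greedy decoding is a plain argmax over the vocabulary.
    if greedy:
        return torch.argmax(logits, dim=-1)
    logits = logits.clone().float() / temperature
    if top_k > 0:
        # Mask out everything below the k-th largest logit.
        kth = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits[logits < kth] = float('-inf')
    if top_p > 0.0:
        # Nucleus filtering: mask the tail whose cumulative probability exceeds top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Shift by one token so the first token crossing the threshold is kept.
        sorted_logits[cumulative - sorted_probs > top_p] = float('-inf')
        # Scatter the filtered logits back to their original vocabulary positions.
        logits = logits.scatter(-1, sorted_idx, sorted_logits)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

The key detail mirrored from the diff is logits[:, -1, :]: during incremental generation only the newest position's distribution is needed to sample the next token, even though the model returns logits for every position in the sequence.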