Commit 554d1cc0 authored by mshoeybi

sampling

parent 018c270a
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Utilities sampling.
+"""Sampling utilities.
 Part of this code is inspired by:
  - https://github.com/ari-holtzman/degen/blob/master/gen.py
  - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
@@ -55,25 +55,23 @@ def modify_logits_for_top_p_filtering(logits, top_p):
 
 
-def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
-                             temperature=1.0, vocab_size=None):
-    """ Sample and update the logits and generate a token.
-    Note: logits has the dimension [b, s, v] where b is the batch size,
-          s is the sequence length, and v is the vocabulary size.
-    Note: logits are modifed in place so the sampling modification
-          are reflected in the original full logits.
+def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
+           vocab_size=None):
+    """ Sample and generate a token.
+    Note: logits has the dimension [b, v] where b is the batch size
+          and v is the vocabulary size.
     If vocab_size is provided, we will make sure the sample that is
     generated is in [0, vocab-size). This will avoid out of vocabulary
     generations due to padding.
     """
 
     # Check logits for consistency.
-    assert logits.ndim == 3, 'expected the logits to be of [b, s, v] shape.'
+    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
     assert logits.type() == 'torch.cuda.FloatTensor', \
         'input logits should be floats.'
 
-    # We always index into the last index in s.
-    logits = logits[:, -1, :]
+    # Clone so we do not modify the inputs.
+    logits = logits.clone()
 
     # Greedy is just simple argmax.
     if greedy:
@@ -106,4 +104,4 @@ def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
 
     if vocab_size:
         samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
-    return samples
+    return samples, logits
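
In summary, the commit renames `sample_and_update_logits` to `sample`, changes the expected input from full-sequence logits of shape [b, s, v] to last-token logits of shape [b, v], clones the logits instead of mutating the caller's tensor in place, and returns the filtered logits alongside the sampled tokens. Below is a minimal sketch of the resulting sampling flow. The bodies of the top-k and top-p helpers are assumptions following the standard recipes from the sources cited in the module docstring (only the top-p helper's signature appears in this diff), and the `torch.cuda.FloatTensor` assert is dropped so the sketch runs on CPU.

```python
import torch


def modify_logits_for_top_k_filtering(logits, top_k):
    """Set all but the k largest logits per row to -inf (in place).
    Assumed body; not shown in this diff."""
    filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(filter_, float('-inf'))


def modify_logits_for_top_p_filtering(logits, top_p):
    """Set logits outside the top-p nucleus to -inf (in place).
    Assumed body following the standard nucleus-sampling recipe."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop tokens once cumulative probability exceeds top_p, shifting the
    # mask right by one so the most likely token is always kept.
    sorted_mask = cumulative_probs > top_p
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    filter_ = sorted_mask.scatter(1, sorted_indices, sorted_mask)
    logits.masked_fill_(filter_, float('-inf'))


def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
           vocab_size=None):
    """Sample one token per row of [b, v] logits; return (samples, logits)."""
    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
    # Clone so we do not modify the inputs (the key change in this commit).
    logits = logits.clone()
    if greedy:
        samples = torch.argmax(logits, dim=-1)
    else:
        if temperature != 1.0:
            logits.div_(temperature)
        # top_k and top_p are treated as mutually exclusive here.
        if top_k > 0:
            modify_logits_for_top_k_filtering(logits, top_k)
        elif top_p > 0.0:
            modify_logits_for_top_p_filtering(logits, top_p)
        probs = logits.softmax(dim=-1)
        samples = torch.multinomial(probs, num_samples=1).view(-1)
    # Clamp so padded vocabulary entries can never be emitted.
    if vocab_size:
        samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
    return samples, logits


if __name__ == '__main__':
    torch.manual_seed(0)
    tokens, filtered = sample(torch.randn(4, 16), top_p=0.9, vocab_size=16)
    print(tokens)  # shape [4], one sampled token id per batch row
```

Returning the cloned, filtered logits lets the caller recover per-token log-probabilities of the sampled ids while leaving the original logits untouched.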