Commit f1555799 authored by mshoeybi

sampling tested

parent 297a5f33
@@ -13,27 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for sampling.

Part of this code is inspired by:
 - https://github.com/ari-holtzman/degen/blob/master/gen.py
 - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
"""

import torch
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non top-k values to -Inf."""
    filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(filter_, float('-Inf'))
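As a quick illustration, here is a minimal sketch of the top-k filter on a toy batch, assuming modify_logits_for_top_k_filtering above is in scope; the tensor values are illustrative only:

import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
modify_logits_for_top_k_filtering(logits, top_k=2)
# Only the two largest logits survive; everything below the k-th
# largest value is masked in place.
print(logits)  # tensor([[-inf, 3.0, 2.0, -inf]])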
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non top-p values to -Inf."""
    # First sort and calculate the cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
@@ -41,50 +41,63 @@ def top_p_filtering(logits, top_p):
    # Filter based on the cumulative sum.
    filter_ = cumulative_probs > top_p
    # Shift the mask right by one so that the first token whose
    # cumulative probability crosses top_p is also kept. This matches
    # the original implementation:
    # https://github.com/ari-holtzman/degen/blob/master/gen.py
    filter_[:, 1:] = filter_[:, :-1].clone()
    # Make sure we have at least one token to select from.
    filter_[..., 0] = 0
    # Fill in the filtered part.
    filter_ = filter_.scatter(1, sorted_indices, filter_)
    logits.masked_fill_(filter_, float('-Inf'))
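A hedged sketch of the top-p filter on a toy distribution; it assumes the elided lines compute cumulative_probs as the cumsum of the sorted softmax (as in the linked reference), and it shows the shift keeping the token that crosses the threshold:

import torch

logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
modify_logits_for_top_p_filtering(logits, top_p=0.7)
# Sorted cumulative probabilities are 0.5, 0.8, 0.95, 1.0; after the
# shift, the first two tokens are kept and the rest are masked to -Inf,
# so the renormalized distribution is roughly [0.625, 0.375, 0, 0].
print(logits.softmax(dim=-1))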
def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
                             temperature=1.0, vocab_size=None):
    """Sample from the logits, update them in place, and generate a token.

    Note: logits has dimension [b, v] where b is the batch size and
    v is the vocabulary size.

    Note: logits are modified in place, so the sampling modifications
    are reflected in the original full logits.

    If vocab_size is provided, we make sure the generated samples are in
    [0, vocab_size). This avoids out-of-vocabulary generations due to
    padding.
    """
    # Check logits for consistency.
    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
    assert logits.is_contiguous(), 'input logits should be contiguous.'
    assert logits.type() == 'torch.cuda.FloatTensor', \
        'input logits should be floats.'
    # Greedy is just simple argmax.
    if greedy:
        assert top_k == 0, 'cannot set both greedy and top-k samplings.'
        assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
        samples = torch.argmax(logits, dim=-1)

    # Top-k or top-p sampling.
    else:
        # Apply temperature in place.
        logits.div_(temperature)

        if top_k > 0:
            assert top_p == 0.0, \
                'cannot set both top-k and top-p samplings.'
            assert top_k <= logits.size(1), \
                'top-k is larger than logit size.'
            if vocab_size:
                assert top_k < vocab_size, 'top-k is larger than vocab size.'
            modify_logits_for_top_k_filtering(logits, top_k)

        elif top_p > 0.0:
            assert top_p <= 1.0, 'top-p should be in (0, 1].'
            modify_logits_for_top_p_filtering(logits, top_p)

        # After filtering, we need to recalculate the distribution.
        probs = logits.softmax(dim=-1)
        samples = torch.multinomial(probs, num_samples=1).view(-1)
    # If vocab size is provided, make sure the samples are in
    # the range [0, vocab_size).
    ...
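A hedged usage sketch of the new entry point, assuming the elided tail of the function returns the sampled tokens (the visible code assigns them to samples); logits must live on GPU to satisfy the CUDA-float assert, and the shapes and values below are illustrative only:

import torch

logits = torch.randn(4, 50304, device='cuda')  # [b, v], padded vocab
tokens = sample_and_update_logits(logits, top_p=0.9, temperature=0.8,
                                  vocab_size=50257)
# tokens is expected to have shape [4]; logits now hold the
# temperature-scaled, filtered values because the function modifies
# them in place.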