Commit 554d1cc0 authored by mshoeybi's avatar mshoeybi
Browse files

sampling

parent 018c270a
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities sampling.
"""Sampling utilities.
Part of this code is inspired by:
- https://github.com/ari-holtzman/degen/blob/master/gen.py
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
......@@ -55,25 +55,23 @@ def modify_logits_for_top_p_filtering(logits, top_p):
def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
temperature=1.0, vocab_size=None):
""" Sample and update the logits and generate a token.
Note: logits has the dimension [b, s, v] where b is the batch size,
s is the sequence length, and v is the vocabulary size.
Note: logits are modified in place so the sampling modifications
are reflected in the original full logits.
def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
vocab_size=None):
""" Sample and generate a token.
Note: logits has the dimension [b, v] where b is the batch size
and v is the vocabulary size.
If vocab_size is provided, we will make sure the sample that is
generated is in [0, vocab_size). This will avoid out-of-vocabulary
generations due to padding.
"""
# Check logits for consistency.
assert logits.ndim == 3, 'expected the logits to be of [b, s, v] shape.'
assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.'
# We always index into the last index in s.
logits = logits[:, -1, :]
# Clone so we do not modify the inputs.
logits = logits.clone()
# Greedy is just simple argmax.
if greedy:
......@@ -106,4 +104,4 @@ def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
if vocab_size:
samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
return samples
return samples, logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment