Commit 8a326bbc authored by Tri Dao

[Gen] Minor fix to modify logits for top_p

parent 1d817a8f
@@ -32,7 +32,7 @@ class InferenceParams:
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
 def modify_logits_for_top_k_filtering(logits, top_k):
-    """Set the logits for non-top-k values to -inf."""
+    """Set the logits for non-top-k values to -inf. Done in-place."""
     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
     logits.masked_fill_(indices_to_remove, float("-Inf"))
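For reference, a minimal usage sketch of the top-k helper above (the tensor values are illustrative and not part of this commit): torch.topk(logits, top_k)[0][..., -1, None] is the k-th largest logit per row, and every entry strictly below it is masked to -inf in place.

    import torch

    logits = torch.tensor([[2.0, 0.5, 1.0, 3.0]])
    modify_logits_for_top_k_filtering(logits, top_k=2)  # helper defined in the hunk above
    print(logits)  # -> [[2., -inf, -inf, 3.]]: only the 2 largest logits survive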
@@ -40,7 +40,7 @@ def modify_logits_for_top_k_filtering(logits, top_k):
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
 def modify_logits_for_top_p_filtering(logits, top_p):
-    """Set the logits for non-top-p values to -inf."""
+    """Set the logits for non-top-p values to -inf. Done in-place."""
     if top_p <= 0.0 or top_p >= 1.0:
         return
     # First sort and calculate cumulative sum of probabilities.
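For intuition, a behavioral sketch of what the finished top-p (nucleus) filter is expected to do once it runs in place, in line with the Megatron-LM / Hugging Face references linked above; the probabilities below are illustrative and not taken from this commit.

    import torch

    # Four tokens with probabilities 0.5, 0.3, 0.1, 0.1.
    logits = torch.log(torch.tensor([[0.5, 0.3, 0.1, 0.1]]))
    modify_logits_for_top_p_filtering(logits, top_p=0.7)
    # Tokens that fall outside the 0.7 nucleus are pushed to -inf,
    # so only the 0.5 and 0.3 tokens keep finite logits.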
@@ -52,7 +52,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):
     indices_to_remove = sorted_indices_to_remove.scatter(
         1, sorted_indices, sorted_indices_to_remove
     )
-    logits = logits.masked_fill(indices_to_remove, float("-inf"))
+    logits.masked_fill_(indices_to_remove, float("-inf"))

 def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
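The substance of the fix is this last hunk: Tensor.masked_fill returns a new tensor, so logits = logits.masked_fill(...) only rebound the local name and the caller's logits were never filtered for top_p; masked_fill_ mutates the tensor in place, matching the "Done in-place." docstrings added above. A minimal standalone sketch of the difference (the helper names here are illustrative, not from the repo):

    import torch

    def filter_out_of_place(logits, mask):
        # Pre-fix pattern: rebinding the local name leaves the caller's tensor untouched.
        logits = logits.masked_fill(mask, float("-inf"))

    def filter_in_place(logits, mask):
        # Post-fix pattern: masked_fill_ modifies the caller's tensor directly.
        logits.masked_fill_(mask, float("-inf"))

    x = torch.zeros(1, 4)
    mask = torch.tensor([[False, True, False, True]])
    filter_out_of_place(x, mask)   # x is unchanged: tensor([[0., 0., 0., 0.]])
    filter_in_place(x, mask)       # x is now tensor([[0., -inf, 0., -inf]])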