Commit 8a326bbc authored by Tri Dao's avatar Tri Dao
Browse files

[Gen] Minor fix to modify logits for top_p

parent 1d817a8f
...@@ -32,7 +32,7 @@ class InferenceParams: ...@@ -32,7 +32,7 @@ class InferenceParams:
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k): def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf.""" """Set the logits for none top-k values to -inf. Done in-place."""
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(indices_to_remove, float("-Inf")) logits.masked_fill_(indices_to_remove, float("-Inf"))
...@@ -40,7 +40,7 @@ def modify_logits_for_top_k_filtering(logits, top_k): ...@@ -40,7 +40,7 @@ def modify_logits_for_top_k_filtering(logits, top_k):
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p): def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf.""" """Set the logits for none top-p values to -inf. Done in-place."""
if top_p <= 0.0 or top_p >= 1.0: if top_p <= 0.0 or top_p >= 1.0:
return return
# First sort and calculate cumulative sum of probabilities. # First sort and calculate cumulative sum of probabilities.
...@@ -52,7 +52,7 @@ def modify_logits_for_top_p_filtering(logits, top_p): ...@@ -52,7 +52,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove 1, sorted_indices, sorted_indices_to_remove
) )
logits = logits.masked_fill(indices_to_remove, float("-inf")) logits.masked_fill_(indices_to_remove, float("-inf"))
def sample(logits, top_k=1, top_p=0.0, temperature=1.0): def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment