Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
8a326bbc
Commit
8a326bbc
authored
Aug 29, 2023
by
Tri Dao
Browse files
[Gen] Minor fix to modify logits for top_p
parent
1d817a8f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
flash_attn/utils/generation.py
flash_attn/utils/generation.py
+3
-3
No files found.
flash_attn/utils/generation.py
View file @
8a326bbc
...
@@ -32,7 +32,7 @@ class InferenceParams:
...
@@ -32,7 +32,7 @@ class InferenceParams:
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def
modify_logits_for_top_k_filtering
(
logits
,
top_k
):
def
modify_logits_for_top_k_filtering
(
logits
,
top_k
):
"""Set the logits for none top-k values to -inf."""
"""Set the logits for none top-k values to -inf.
Done in-place.
"""
indices_to_remove
=
logits
<
torch
.
topk
(
logits
,
top_k
)[
0
][...,
-
1
,
None
]
indices_to_remove
=
logits
<
torch
.
topk
(
logits
,
top_k
)[
0
][...,
-
1
,
None
]
logits
.
masked_fill_
(
indices_to_remove
,
float
(
"-Inf"
))
logits
.
masked_fill_
(
indices_to_remove
,
float
(
"-Inf"
))
...
@@ -40,7 +40,7 @@ def modify_logits_for_top_k_filtering(logits, top_k):
...
@@ -40,7 +40,7 @@ def modify_logits_for_top_k_filtering(logits, top_k):
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def
modify_logits_for_top_p_filtering
(
logits
,
top_p
):
def
modify_logits_for_top_p_filtering
(
logits
,
top_p
):
"""Set the logits for none top-p values to -inf."""
"""Set the logits for none top-p values to -inf.
Done in-place.
"""
if
top_p
<=
0.0
or
top_p
>=
1.0
:
if
top_p
<=
0.0
or
top_p
>=
1.0
:
return
return
# First sort and calculate cumulative sum of probabilities.
# First sort and calculate cumulative sum of probabilities.
...
@@ -52,7 +52,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):
...
@@ -52,7 +52,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):
indices_to_remove
=
sorted_indices_to_remove
.
scatter
(
indices_to_remove
=
sorted_indices_to_remove
.
scatter
(
1
,
sorted_indices
,
sorted_indices_to_remove
1
,
sorted_indices
,
sorted_indices_to_remove
)
)
logits
=
logits
.
masked_fill
(
indices_to_remove
,
float
(
"-inf"
))
logits
.
masked_fill
_
(
indices_to_remove
,
float
(
"-inf"
))
def
sample
(
logits
,
top_k
=
1
,
top_p
=
0.0
,
temperature
=
1.0
):
def
sample
(
logits
,
top_k
=
1
,
top_p
=
0.0
,
temperature
=
1.0
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment