OpenDAS / Megatron-LM

Commit f1555799, authored Sep 21, 2021 by mshoeybi
Parent: 297a5f33

Commit message: sampling tested
Showing 1 changed file with 37 additions and 24 deletions:

megatron/inference/sampling.py (+37, -24)
```diff
@@ -13,27 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Utilities sampling."""
+"""Utilities sampling.
+Part of this code is inspired by:
+ - https://github.com/ari-holtzman/degen/blob/master/gen.py
+ - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
+"""
 
 import torch
 
 
-def top_k_filtering(logits, top_k):
-    """Pick top-k logits."""
+def modify_logits_for_top_k_filtering(logits, top_k):
+    """Set the logits for none top-k values to -inf."""
     filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
     logits.masked_fill_(filter_, float('-Inf'))
-    return logits
 
 
-def top_p_filtering(logits, top_p):
-    """Pick top-p logits.
-    Part of the code is adopted from:
-    https://huggingface.co/transformers/_modules/transformers/\
-    generation_logits_process.html#TopPLogitsWarper
-    """
+def modify_logits_for_top_p_filtering(logits, top_p):
+    """Set the logits for none top-p values to -inf."""
     # First sort and calculate cumulative sum of probabilities.
     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
```
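For readers skimming the diff, here is a minimal, self-contained sketch (not part of the commit; it uses CPU toy values) of what the renamed `modify_logits_for_top_k_filtering` does: every logit below the k-th largest in its row is masked to `-inf`, so a subsequent softmax assigns it zero probability.

```python
import torch

# Toy batch of one row of logits; values are illustrative only.
logits = torch.tensor([[2.0, 0.5, 1.0, 3.0, -1.0]])
top_k = 2

# The k-th largest value per row, shape [b, 1], acts as a threshold.
threshold = torch.topk(logits, top_k)[0][..., -1, None]
# Mask everything strictly below the threshold to -inf, in place.
filter_ = logits < threshold
logits.masked_fill_(filter_, float('-Inf'))
print(logits)  # tensor([[2., -inf, -inf, 3., -inf]])
```

Note that the renamed function mutates `logits` in place and no longer returns it, which is why the commit also drops the `return logits` line.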
```diff
@@ -41,50 +41,63 @@ def top_p_filtering(logits, top_p):
 
     # Filteration based on the cumulative sum.
     filter_ = cumulative_probs > top_p
+    # This shift by 1 is weird and I cannot justify it. This existed
+    # in the original implementation:
+    #   https://github.com/ari-holtzman/degen/blob/master/gen.py
+    # and I guess it is needed so keeping it for now.
+    filter_[:, 1:] = filter_[:, :-1].clone()
     # Make sure we at least have one token to select from.
     filter_[..., 0] = 0
     # Fill in the filtered part
     filter_ = filter_.scatter(1, sorted_indices, filter_)
     logits.masked_fill_(filter_, float('-Inf'))
-    return logits
```
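The shift by 1 that the new comment flags does have a purpose: without it, the first token whose cumulative probability crosses `top_p` would itself be filtered, and the surviving probability mass could fall below `top_p`. Shifting the mask right by one keeps that boundary token. A small sketch with assumed toy values (not part of the commit) traces the steps:

```python
import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])
top_p = 0.8

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# cumulative_probs is approximately [[0.64, 0.88, 0.97, 1.00]]

filter_ = cumulative_probs > top_p  # [[False, True, True, True]]
# Without the shift, the boundary token (cumsum 0.88) would be dropped
# and the kept mass would be only 0.64 < top_p. The shift keeps it.
filter_[:, 1:] = filter_[:, :-1].clone()  # [[False, False, True, True]]
filter_[..., 0] = 0  # always keep at least the most likely token
# Map the mask from sorted order back to the original vocabulary order.
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
print(logits)  # tensor([[-inf, 3., 2., -inf]]), kept mass ~0.88 >= top_p
```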
The same hunk continues with the renamed sampling entry point:

```diff
 
 
-def sample_logits(logits, greedy=False, top_k=0.0, top_p=0.0,
-                  temperature=1.0, vocab_size=None):
-    """ Sample the logit and generate a token.
+def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
+                             temperature=1.0, vocab_size=None):
+    """ Sample and update the logits and generate a token.
     Note: logits has the dimension [b, v] where b is the batch size
-    and v is the vocabulary size. """
+    and v is the vocabulary size.
+    Note: logits are modifed in place so the sampling modification
+    are reflected in the original full logits.
+    If vocab_size is provided, we will make sure the sample that is
+    generated is in [0, vocab-size). This will avoid out of vocabulary
+    generations due to padding.
+    """
 
     # Check logits for consistency.
     assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
     assert logits.is_contiguous(), 'input logits should be contiguous.'
+    assert logits.type() == 'torch.cuda.FloatTensor', \
+        'input logits should be floats.'
 
     # Greedy is just simple argmax.
     if greedy:
-        assert top_k == 0.0, 'cannot set both greedy and top-k samplings.'
+        assert top_k == 0, 'cannot set both greedy and top-k samplings.'
         assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
         samples = torch.argmax(logits, dim=-1)
 
     # Top-k or top-p sampling.
     else:
-        # Convert to float so opts are more accurate and apply temperature.
-        logits = logits.float() / temperature
+        # Apply temperature in place.
+        logits.div_(temperature)
         if top_k > 0:
             assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
             assert top_k <= logits.size(1), 'top-k is larger than logit size.'
             if vocab_size:
                 assert top_k < vocab_size, 'top-k is larger than vocab size.'
-            logits = top_k_filtering(logits, top_k)
-        else:
-            assert top_p > 0.0 and top_p <= 1.0, 'top-p should be in (0, 1].'
-            logits = top_p_filtering(logits, top_p)
+            modify_logits_for_top_k_filtering(logits, top_k)
+        elif top_p > 0.0:
+            assert top_p <= 1.0, 'top-p should be in (0, 1].'
+            modify_logits_for_top_p_filtering(logits, top_p)
         # After filtering, we need to recalculate the distribution.
-        logits = logits.softmax(dim=-1)
-        samples = torch.multinomial(logits, num_samples=1).view(-1)
+        probs = logits.softmax(dim=-1)
+        samples = torch.multinomial(probs, num_samples=1).view(-1)
 
     # If vocab size is provided, make sure the samples are in
     # in the range [0, vocab-size).
```

...
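A hypothetical call site for the renamed `sample_and_update_logits` (none of this appears in the commit; the names and sizes are illustrative, and it assumes the elided tail of the function returns `samples`). It also assumes a CUDA device, since the new assertion requires `logits.type() == 'torch.cuda.FloatTensor'`:

```python
import torch
from megatron.inference.sampling import sample_and_update_logits

# Illustrative sizes: a padded vocabulary for tensor-core alignment and
# the true (unpadded) vocabulary size, as in GPT-2 style tokenizers.
batch_size, padded_vocab, true_vocab = 4, 50304, 50257
logits = torch.randn(batch_size, padded_vocab, device='cuda')

# Top-p sampling with temperature. Per the new docstring, logits are
# modified in place, and vocab_size keeps samples inside [0, true_vocab).
tokens = sample_and_update_logits(logits, top_p=0.9, temperature=0.8,
                                  vocab_size=true_vocab)
print(tokens.shape)  # torch.Size([4]): one sampled token id per row
```

Because the commit switches to in-place updates (`logits.div_(temperature)` and the `masked_fill_` calls), the caller's logits tensor reflects the temperature scaling and filtering after the call, which is what the new "modified in place" docstring note is about.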