OpenDAS / Megatron-LM · Commit 018c270a
Authored Sep 21, 2021 by mshoeybi
Commit message: sampling
Parent: f1555799
Showing 1 changed file with 6 additions and 4 deletions (+6 −4).
megatron/inference/sampling.py (view file @ 018c270a)

...
@@ -58,8 +58,8 @@ def modify_logits_for_top_p_filtering(logits, top_p):
 def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
                              temperature=1.0, vocab_size=None):
     """ Sample and update the logits and generate a token.
-    Note: logits has the dimension [b, v] where b is the batch size
-          and v is the vocabulary size.
+    Note: logits has the dimension [b, s, v] where b is the batch size,
+          s is the sequence length, and v is the vocabulary size.
     Note: logits are modifed in place so the sampling modification
           are reflected in the original full logits.
     If vocab_size is provided, we will make sure the sample that is
@@ -68,11 +68,13 @@ def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
     """
     # Check logits for consistency.
-    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
+    assert logits.ndim == 3, 'expected the logits to be of [b, s, v] shape.'
     assert logits.is_contiguous(), 'input logits should be contiguous.'
     assert logits.type() == 'torch.cuda.FloatTensor', \
         'input logits should be floats.'
+    # We always index into the last index in s.
+    logits = logits[:, -1, :]
     # Greedy is just simple argmax.
     if greedy:
         assert top_k == 0, 'cannot set both greedy and top-k samplings.'
...
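The change above widens the sampler's input from [b, v] to [b, s, v] and slices out the last sequence position before sampling. A minimal sketch of that slicing-plus-greedy-argmax logic, using NumPy in place of PyTorch (the real code asserts a contiguous CUDA float tensor); `sample_last_position` is a hypothetical helper, not a function from this file:

```python
import numpy as np

def sample_last_position(logits, greedy=True):
    """Hypothetical helper mirroring the commit's slicing logic."""
    assert logits.ndim == 3, 'expected the logits to be of [b, s, v] shape.'
    # We always index into the last index in s (as in the commit).
    last = logits[:, -1, :]              # shape [b, v]
    if greedy:
        # Greedy is just simple argmax over the vocabulary axis.
        return np.argmax(last, axis=-1)  # one token id per batch row
    raise NotImplementedError('top-k/top-p sampling not sketched here')

batch, seq, vocab = 2, 5, 8
rng = np.random.default_rng(0)
logits = rng.standard_normal((batch, seq, vocab)).astype(np.float32)
tokens = sample_last_position(logits)
print(tokens.shape)  # → (2,)
```

Passing the full [b, s, v] tensor and slicing inside the sampler lets the caller hand over the model's raw output unchanged, while sampling still happens only at the final position.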
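The hunk header references `modify_logits_for_top_p_filtering`, whose body is not shown in this diff. For orientation, a NumPy sketch of the conventional nucleus (top-p) filtering algorithm, assuming the standard formulation rather than this file's actual implementation (which modifies torch logits in place):

```python
import numpy as np

def modify_logits_for_top_p_sketch(logits, top_p):
    """Set to -inf all but the smallest set of logits whose softmax
    probabilities sum to at least top_p. Hypothetical sketch only."""
    # Softmax over the vocabulary axis (numerically stabilized).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Sort probabilities in descending order and accumulate mass.
    order = np.argsort(-probs, axis=-1)
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    cumulative = np.cumsum(sorted_probs, axis=-1)
    # Keep each token whose preceding cumulative mass is below top_p,
    # so the most probable token is always retained.
    sorted_keep = cumulative - sorted_probs < top_p
    keep = np.zeros_like(sorted_keep)
    np.put_along_axis(keep, order, sorted_keep, axis=-1)
    # Masked-out tokens get -inf so softmax assigns them zero probability.
    return np.where(keep, logits, -np.inf)
```

With a sharply peaked distribution and a small `top_p`, only the dominant token survives the filter; flatter distributions retain more of the vocabulary.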