OpenDAS / Megatron-LM

Commit 554d1cc0
Authored Sep 22, 2021 by mshoeybi
Parent: 018c270a

    sampling
Showing 1 changed file with 10 additions and 12 deletions:

megatron/inference/sampling.py  (+10, -12)
--- a/megatron/inference/sampling.py
+++ b/megatron/inference/sampling.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Utilities sampling.
+"""Sampling utilities.
 Part of this code is inspired by:
     - https://github.com/ari-holtzman/degen/blob/master/gen.py
     - https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
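The next hunk renames the sampling entry point; its header shows that it sits just after modify_logits_for_top_p_filtering(logits, top_p), whose body is not part of this commit. For context, a minimal nucleus-filtering sketch in the style of the degen reference cited in the docstring above could look like the following; the in-place masking details are assumptions, not the file's actual code:

import torch

def modify_logits_for_top_p_filtering(logits, top_p):
    """Nucleus filtering sketch: mask out logits outside the smallest set
    of tokens whose cumulative probability exceeds top_p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Mark tokens past the nucleus for removal.
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token crossing the threshold is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Map the mask back from sorted order to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove)
    # Filtered logits get -inf so softmax assigns them zero probability.
    logits.masked_fill_(indices_to_remove, float('-inf'))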
@@ -55,25 +55,23 @@ def modify_logits_for_top_p_filtering(logits, top_p):
-def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
-                             temperature=1.0,
-                             vocab_size=None):
-    """ Sample and update the logits and generate a token.
-    Note: logits has the dimension [b, s, v] where b is the batch size,
-          s is the sequence length, and v is the vocabulary size.
-    Note: logits are modifed in place so the sampling modification
-          are reflected in the original full logits.
+def sample(logits, greedy=False, top_k=0, top_p=0.0,
+           temperature=1.0,
+           vocab_size=None):
+    """ Sample and generate a token.
+    Note: logits has the dimension [b, v] where b is the batch size
+          and v is the vocabulary size.
     If vocab_size is provided, we will make sure the sample that is
     generated is in [0, vocab-size). This will avoid out of vocabulary
     generations due to padding.
     """
     # Check logits for consistency.
-    assert logits.ndim == 3, 'expected the logits to be of [b, s, v] shape.'
+    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
     assert logits.type() == 'torch.cuda.FloatTensor', \
         'input logits should be floats.'
-    # We always index into the last index in s.
-    logits = logits[:, -1, :]
+    # Clone so we do not modify the inputs,
+    logits = logits.clone()
     # Greedy is just simple argmax.
     if greedy:
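The diff collapses the body after if greedy:. Judging from the visible signature and the clamp in the final hunk, the hidden branch presumably scales by temperature, applies the top-k/top-p filters, and draws from the resulting distribution. A sketch under those assumptions (the top-k helper name merely mirrors the top-p one and is itself a guess):

    if greedy:
        samples = torch.argmax(logits, dim=-1)
    else:
        # Soften or sharpen the distribution before filtering.
        if temperature != 1.0:
            logits = logits / temperature
        # Keep only the top-k and/or nucleus candidates.
        if top_k > 0:
            modify_logits_for_top_k_filtering(logits, top_k)  # assumed helper
        if top_p > 0.0:
            modify_logits_for_top_p_filtering(logits, top_p)
        # One multinomial draw per batch row.
        probs = logits.softmax(dim=-1)
        samples = torch.multinomial(probs, num_samples=1).view(-1)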
@@ -106,4 +104,4 @@ def sample_and_update_logits(logits, greedy=False, top_k=0, top_p=0.0,
     if vocab_size:
         samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
 
-    return samples, logits
+    return samples
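After this change, sample no longer returns the (previously in-place-modified) logits, and callers are responsible for selecting the last sequence position before calling it, since the input is now [b, v] rather than [b, s, v]. A minimal usage sketch; the batch size, vocabulary numbers, and sampling parameters are illustrative assumptions:

import torch
from megatron.inference.sampling import sample

# Logits for the last position only: [b, v] per the new docstring.
# Vocabulary padded to 50304 for efficiency; only 50257 real tokens.
last_logits = torch.randn(4, 50304).float().cuda()  # requires a GPU

tokens = sample(last_logits, top_p=0.9, temperature=0.8, vocab_size=50257)
print(tokens)  # 4 token ids, each clamped into [0, 50257)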