chenpangpang / transformers · Commits · f96ce1c2

Commit f96ce1c2, authored Oct 31, 2019 by Julien Chaumond
[run_generation] Fix generation with batch_size>1
Parent: 3c1b6f59
Showing 1 changed file with 7 additions and 6 deletions (+7, -6).
examples/run_generation.py
@@ -79,7 +79,7 @@ def set_seed(args):
 def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
-            logits: logits distribution shape (vocabulary size)
+            logits: logits distribution shape (batch size x vocabulary size)
             top_k > 0: keep only top k tokens with highest probability (top-k filtering).
             top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
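
For context on the shape change: below is a minimal standalone sketch of how a top-k / top-p filter can operate row-wise on (batch size x vocabulary size) logits. The helper name batched_top_k_top_p and the toy shapes are illustrative assumptions, not the file's actual implementation.

    # Hedged sketch: `batched_top_k_top_p` is a hypothetical name, not from this file.
    import torch
    import torch.nn.functional as F

    def batched_top_k_top_p(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        """Filter (batch size x vocabulary size) logits row-wise."""
        if top_k > 0:
            # Drop every token whose logit is below its row's k-th largest logit
            kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth, filter_value)
        if top_p > 0.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cum_probs > top_p
            # Shift the mask right so the first token crossing top_p is kept
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            # Map the mask back from sorted order to vocabulary order
            remove = remove.scatter(-1, sorted_idx, remove)
            logits = logits.masked_fill(remove, filter_value)
        return logits

    if __name__ == "__main__":
        logits = torch.randn(2, 10)  # batch of 2 samples, toy vocabulary of 10
        print(batched_top_k_top_p(logits, top_k=3, top_p=0.9))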
@@ -138,13 +138,14 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
             outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
             next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)

-            # reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
-            for _ in set(generated.view(-1).tolist()):
-                next_token_logits[_] /= repetition_penalty
+            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
+            for i in range(num_samples):
+                for _ in set(generated[i].tolist()):
+                    next_token_logits[i, _] /= repetition_penalty

             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-            if temperature == 0: #greedy sampling:
-                next_token = torch.argmax(filtered_logits).unsqueeze(0)
+            if temperature == 0: # greedy sampling:
+                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
             else:
                 next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
             generated = torch.cat((generated, next_token), dim=1)
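
Why the two changed statements matter once there is a batch dimension (a toy illustration with made-up shapes, not code from this file): indexing a 2-D logits tensor with a bare token id selects a whole row along the batch axis, and torch.argmax without dim flattens the tensor to a single scalar index.

    import torch

    num_samples, vocab_size = 2, 5
    next_token_logits = torch.ones(num_samples, vocab_size)
    repetition_penalty = 1.3

    # Pre-fix penalty indexing: a bare token id hits the batch axis,
    # scaling every vocabulary entry of one sample.
    buggy = next_token_logits.clone()
    buggy[1] /= repetition_penalty            # whole row 1 scaled

    # Post-fix indexing: penalize token 1 for sample 0 only.
    fixed = next_token_logits.clone()
    fixed[0, 1] /= repetition_penalty         # single entry scaled

    # Pre-fix greedy step: argmax without `dim` flattens the batch,
    # returning one scalar index shared by all samples.
    print(torch.argmax(fixed).shape)                        # torch.Size([])
    # Post-fix: one argmax per row, shaped (num_samples, 1) like
    # the torch.multinomial output in the sampling branch.
    print(torch.argmax(fixed, dim=-1).unsqueeze(-1).shape)  # torch.Size([2, 1])

Keeping next_token shaped (num_samples, 1) in both branches is what lets torch.cat((generated, next_token), dim=1) append one new token to every sequence in the batch.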