chenpangpang / transformers · Commits

Commit 2a5663c2
authored Oct 31, 2019 by Julien Chaumond

Merge branch 'mataney-fix_top_k_top_p_filtering'

parents fa735208 f96ce1c2

Showing 1 changed file with 18 additions and 14 deletions:

examples/run_generation.py (+18, -14)
```diff
@@ -79,13 +79,12 @@ def set_seed(args):
 def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
-            logits: logits distribution shape (vocabulary size)
+            logits: logits distribution shape (batch size x vocabulary size)
             top_k > 0: keep only top k tokens with highest probability (top-k filtering).
             top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                 Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
         From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
-    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
     top_k = min(top_k, logits.size(-1))  # Safety check
     if top_k > 0:
         # Remove all tokens with a probability less than the last token of the top-k
```
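For reference, this is roughly how the filtering function reads once the batch-size-1 assertion is dropped: a self-contained sketch assembled from this hunk and the next one, with the lines the diff elides (the top-k branch and the sort/cumsum setup) assumed to match the thomwolf gist linked in the docstring.

```python
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a (batch size x vocabulary size) logits tensor using top-k and/or nucleus (top-p) filtering."""
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold,
        # shifted right so the first token above the threshold is kept
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the sorted mask back to the original vocabulary indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
```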
```diff
@@ -102,7 +101,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
 
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
         logits[indices_to_remove] = filter_value
     return logits
```
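This scatter line is the heart of the fix. On a 1-D tensor, boolean indexing with `sorted_indices[sorted_indices_to_remove]` yields exactly the vocabulary ids to mask; on a 2-D batch it flattens the selection into a 1-D index tensor, which would then index rows of `logits` rather than individual entries. `scatter` instead maps the removal mask from sorted order back to vocabulary order row by row. A minimal sketch (the toy logits and the 0.9 threshold are illustrative):

```python
import torch
import torch.nn.functional as F

# Toy batch: 2 rows x 5 vocabulary entries (values are made up).
logits = torch.tensor([[1.0, 5.0, 3.0, 2.0, 4.0],
                       [2.0, 1.0, 0.0, 3.0, 5.0]])

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

# Mark tokens whose cumulative probability exceeds the nucleus threshold,
# then shift right so the first token above the threshold is still kept.
sorted_indices_to_remove = cumulative_probs > 0.9
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

# Row i, sorted position j: set the mask at original column
# sorted_indices[i, j] to the removal flag at sorted position j.
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits[indices_to_remove] = -float('Inf')
```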
```diff
@@ -136,18 +136,19 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
             inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
 
         outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
-        next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
+        next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
 
-        # reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
-        for _ in set(generated.view(-1).tolist()):
-            next_token_logits[_] /= repetition_penalty
+        # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
+        for i in range(num_samples):
+            for _ in set(generated[i].tolist()):
+                next_token_logits[i, _] /= repetition_penalty
 
         filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-        if temperature == 0: #greedy sampling:
-            next_token = torch.argmax(filtered_logits).unsqueeze(0)
+        if temperature == 0: # greedy sampling:
+            next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
         else:
             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-        generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
+        generated = torch.cat((generated, next_token), dim=1)
     return generated
```
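Every changed line in this hunk is a shape adjustment for the new leading batch dimension: logits are now taken for all rows (`[:, -1, :]`), the repetition penalty is applied per sample, and both sampling branches produce a `(num_samples, 1)` column. A quick sketch of the shapes involved (`num_samples` and `vocab_size` are arbitrary here):

```python
import torch
import torch.nn.functional as F

num_samples, vocab_size = 3, 10
# Stand-in for outputs[0][:, -1, :]: one row of next-token logits per sample.
next_token_logits = torch.randn(num_samples, vocab_size)

# Greedy: argmax over the vocab dim, then unsqueeze to a (num_samples, 1) column.
greedy = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
assert greedy.shape == (num_samples, 1)

# Sampling: multinomial already returns (num_samples, 1).
sampled = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
assert sampled.shape == (num_samples, 1)

# Either column concatenates cleanly onto a (num_samples, seq_len) history,
# which is why the old next_token.unsqueeze(0) is no longer needed.
generated = torch.zeros(num_samples, 5, dtype=torch.long)
generated = torch.cat((generated, greedy), dim=1)  # -> (num_samples, 6)
```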
```diff
@@ -161,6 +162,7 @@ def main():
     parser.add_argument("--padding_text", type=str, default="")
     parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20)
+    parser.add_argument("--num_samples", type=int, default=1)
     parser.add_argument("--temperature", type=float, default=1.0,
                         help="temperature of 0 implies greedy sampling")
     parser.add_argument("--repetition_penalty", type=float, default=1.0,
```
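With the new flag the script can emit several completions per prompt, e.g. `python examples/run_generation.py --model_type gpt2 --model_name_or_path gpt2 --num_samples 3 --length 20` (the model flags shown here are assumed from the script's existing interface; only `--num_samples` is introduced by this commit, defaulting to 1 so single-sample behavior is unchanged).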
```diff
@@ -230,6 +232,7 @@ def main():
         out = sample_sequence(
             model=model,
             context=context_tokens,
+            num_samples=args.num_samples,
             length=args.length,
             temperature=args.temperature,
             top_k=args.top_k,
```
```diff
@@ -241,12 +244,13 @@ def main():
             xlm_lang=xlm_lang,
             device=args.device,
         )
-        out = out[0, len(context_tokens):].tolist()
-        text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
-        text = text[: text.find(args.stop_token) if args.stop_token else None]
-        print(text)
+        out = out[:, len(context_tokens):].tolist()
+        for o in out:
+            text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
+            text = text[: text.find(args.stop_token) if args.stop_token else None]
+            print(text)
 
         if args.prompt:
             break
     return text
```
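The stop-token truncation idiom carried into the new loop has a quirk worth noting: `str.find` returns -1 when the token is absent, so the slice falls back to the full string (`text[:None]`) only when `args.stop_token` is empty, not when the token simply never appears in the output. A small illustration (the strings are made up):

```python
# "<eos>" stands in for args.stop_token.
text = "Hello world<eos> trailing junk"
stop_token = "<eos>"
print(text[: text.find(stop_token) if stop_token else None])  # -> "Hello world"

# Quirk: if the stop token is set but never generated, find() returns -1
# and the slice silently drops the last character.
text = "no stop token here"
print(text[: text.find(stop_token) if stop_token else None])  # -> "no stop token her"
```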