chenpangpang / transformers · Commits

Commit a9f24a16 authored Sep 25, 2019 by mataney

[FIX] fix run_generation.py to work with batch_size > 1

parent 7c0f2d0a
Showing 1 changed file with 10 additions and 7 deletions (+10, -7)

examples/run_generation.py
@@ -81,7 +81,6 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
         Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
         From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
-    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
     top_k = min(top_k, logits.size(-1))  # Safety check
     if top_k > 0:
         # Remove all tokens with a probability less than the last token of the top-k
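The assert could be dropped because the rest of the function indexes with `...`, which broadcasts over a leading batch dimension. A minimal sketch of the top-k branch, assuming it computes the k-th largest logit per row via torch.topk as in the linked gist (sizes are illustrative):

    import torch

    logits = torch.randn(3, 8)                        # hypothetical batch of 3 rows
    top_k = min(2, logits.size(-1))                   # Safety check, as in the file
    kth_value = torch.topk(logits, top_k)[0][..., -1, None]  # (3, 1) column per row
    logits[logits < kth_value] = -float('Inf')        # same code works for 1D logits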
@@ -98,7 +97,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
         logits[indices_to_remove] = filter_value
     return logits
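The scatter is what makes the nucleus branch batch-safe: with a 2D mask, the old expression `sorted_indices[sorted_indices_to_remove]` returns a flat 1D tensor of token ids and loses the row structure, whereas scattering the sorted mask back through `sorted_indices` yields a (batch, vocab) boolean mask aligned with the original token positions. A small worked example (values illustrative):

    import torch

    sorted_indices = torch.tensor([[2, 0, 1],
                                   [1, 2, 0]])
    sorted_mask = torch.tensor([[False, False, True],
                                [False, True,  True]])
    print(sorted_indices[sorted_mask])   # tensor([1, 2, 0]) - row structure lost
    mask = sorted_mask.scatter(dim=1, index=sorted_indices, src=sorted_mask)
    print(mask)                          # tensor([[False,  True, False],
                                         #         [ True, False,  True]])

Because `sorted_indices` is a full permutation of each row, every output position gets written, so using the sorted mask itself as the base tensor of the scatter is safe here.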
@@ -122,10 +122,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

             outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
-            next_token_logits = outputs[0][0, -1, :] / temperature
+            next_token_logits = outputs[0][:, -1, :] / temperature
             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
             next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
+            generated = torch.cat((generated, next_token), dim=1)
     return generated
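Shape-wise, `outputs[0]` is now treated as (num_samples, seq_len, vocab): `[:, -1, :]` keeps one row of final-position logits per sample, and torch.multinomial already returns a (num_samples, 1) column, so the extra `unsqueeze(0)` is no longer needed before the concat. A quick check with hypothetical sizes:

    import torch
    import torch.nn.functional as F

    batch, seq_len, vocab = 3, 12, 50257              # illustrative sizes only
    logits = torch.randn(batch, seq_len, vocab)       # stands in for outputs[0]
    next_token_logits = logits[:, -1, :] / 0.9        # (3, 50257), one row per sample
    next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
    generated = torch.randint(0, vocab, (batch, seq_len))
    generated = torch.cat((generated, next_token), dim=1)
    print(next_token.shape, generated.shape)          # (3, 1) and (3, 13)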
@@ -139,6 +139,7 @@ def main():
     parser.add_argument("--padding_text", type=str, default="")
     parser.add_argument("--length", type=int, default=20)
     parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--num_samples", type=int, default=1)
     parser.add_argument("--top_k", type=int, default=0)
     parser.add_argument("--top_p", type=float, default=0.9)
     parser.add_argument("--no_cuda", action='store_true',
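With the new flag, one invocation can draw several samples from the same prompt, e.g. (assuming the script's usual --model_type/--model_name_or_path flags; values are illustrative):

    python examples/run_generation.py --model_type=gpt2 --model_name_or_path=gpt2 --num_samples=3 --top_p=0.9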
@@ -176,6 +177,7 @@ def main():
         out = sample_sequence(
             model=model,
             context=context_tokens,
+            num_samples=args.num_samples,
             length=args.length,
             temperature=args.temperature,
             top_k=args.top_k,
@@ -183,9 +185,10 @@ def main():
             device=args.device,
             is_xlnet=bool(args.model_type == "xlnet"),
         )
-        out = out[0, len(context_tokens):].tolist()
-        text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
-        print(text)
+        out = out[:, len(context_tokens):].tolist()
+        for o in out:
+            text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
+            print(text)
         if args.prompt:
             break
     return text
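With num_samples > 1, `out` carries one generated sequence per row, so the context is stripped per row and each completion is decoded separately. A tiny illustration with made-up token ids (the script calls tokenizer.decode on each row instead of print):

    import torch

    out = torch.tensor([[11, 12, 13, 101, 102],   # 2 samples: 3 context tokens
                        [11, 12, 13, 201, 202]])  # + 2 generated tokens each
    context_tokens = [11, 12, 13]
    for o in out[:, len(context_tokens):].tolist():
        print(o)   # [101, 102] then [201, 202]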