chenpangpang / transformers · Commits

Commit 8e5587fb, authored Dec 18, 2019 by thomwolf

    few fixes on sampling

Parent: 641a8dec

Showing 1 changed file (transformers/modeling_utils.py) with 42 additions and 53 deletions (+42 / -53).
@@ -23,14 +23,12 @@ import json
import logging
import os
from io import open
import warnings
import six
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from tqdm import trange
from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
@@ -82,7 +80,6 @@ class PreTrainedModel(nn.Module):
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__))
        # Save config in model
        self.config = config
@@ -220,9 +217,6 @@ class PreTrainedModel(nn.Module):
        # Tie weights if needed
        self.tie_weights()

        # Initialize decoding head if we have output embeddings

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.
@@ -569,30 +563,36 @@ class PreTrainedModel(nn.Module):
        cur_len = input_ids.shape[1]
        vocab_size = self.config.vocab_size

+       if num_return_sequences != 1:
+           # Expand input to num return sequences
+           input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
+           input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len)  # (batch_size * num_return_sequences, cur_len)
+           effective_batch_size = batch_size * num_return_sequences
+       else:
+           effective_batch_size = batch_size

        if num_beams > 1:
-           return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
-                                             repetition_penalty, pad_token_id, eos_token_ids, batch_size,
-                                             num_return_sequences, length_penalty, num_beams, vocab_size)
-       return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
-                                            repetition_penalty, pad_token_id, eos_token_ids, batch_size,
-                                            num_return_sequences)
+           output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
+                                               repetition_penalty, pad_token_id, eos_token_ids, effective_batch_size,
+                                               length_penalty, num_beams, vocab_size)
+       else:
+           output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
+                                                  repetition_penalty, pad_token_id, eos_token_ids, effective_batch_size)

+       if num_return_sequences != 1:
+           output = output.view(batch_size, num_return_sequences, -1)
+       return output
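In isolation, the reshaping that generate() now performs for num_return_sequences can be sketched as follows (a standalone illustration with made-up shapes, not a quote from the file):

import torch

batch_size, cur_len, num_return_sequences = 2, 3, 4
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

# Repeat each example num_return_sequences times, then flatten so every copy
# is decoded as an independent row of the "effective" batch.
expanded = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
flat = expanded.contiguous().view(batch_size * num_return_sequences, cur_len)
print(flat.shape)  # torch.Size([8, 3])

# After decoding, the flat (effective_batch_size, seq_len) output is folded
# back into (batch_size, num_return_sequences, seq_len).
fake_output = torch.zeros(batch_size * num_return_sequences, 10, dtype=torch.long)
print(fake_output.view(batch_size, num_return_sequences, -1).shape)  # torch.Size([2, 4, 10])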
    def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
-                                repetition_penalty, pad_token_id, eos_token_ids, batch_size, num_return_sequences):
-       """ Generate `num_return_sequences` sequences per batch example without beam search (num_beams == 1).
+                                repetition_penalty, pad_token_id, eos_token_ids, batch_size):
+       """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequence are generated independantly.
        """
-       # Expand input to num return sequences
-       input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
-       input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len)  # (batch_size*num_return_sequences, cur_len)

        # current position / max lengths / length of generated sentences / unfinished sentences
-       unfinished_sents = input_ids.new(batch_size * num_return_sequences).fill_(1)
+       unfinished_sents = input_ids.new(batch_size).fill_(1)

        # cache compute states
        pasts = None
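For context, the bookkeeping that this per-row flag supports looks roughly like the sketch below (simplified, with hypothetical token ids; the real loop also updates sentence lengths and past states):

import torch

batch_size, eos_token_id, pad_token_id = 3, 2, 0
input_ids = torch.tensor([[5, 7], [5, 7], [5, 7]])

# One flag per row of the (already expanded) batch: 1 while still generating.
unfinished_sents = input_ids.new(batch_size).fill_(1)

# Suppose the second row just produced EOS: its flag is cleared, so later
# steps append pad_token_id for that row instead of sampled tokens.
next_tokens = torch.tensor([9, eos_token_id, 4])
tokens_to_add = next_tokens * unfinished_sents + pad_token_id * (1 - unfinished_sents)
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
unfinished_sents.mul_(next_tokens.ne(eos_token_id).long())

# If max_length is reached first, the last column of still-unfinished rows is
# overwritten with EOS, mirroring the masked_fill_ call touched in this commit.
input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_id)
print(input_ids)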
@@ -604,7 +604,7 @@ class PreTrainedModel(nn.Module):
            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
-               for i in range(batch_size * num_return_sequences):
+               for i in range(batch_size):
                    for previous_tokens in set(input_ids[i].tolist()):
                        next_token_logits[i, previous_tokens] /= repetition_penalty
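As a toy illustration of that repetition penalty (the CTRL-style division, now applied once per row of the already-expanded batch), with made-up logits:

import torch

repetition_penalty = 1.2
batch_size = 2
input_ids = torch.tensor([[3, 5, 3], [1, 2, 4]])
next_token_logits = torch.randn(batch_size, 8)

# Divide the logit of every token that already occurs in a sequence, making it
# less likely to be sampled again (https://arxiv.org/abs/1909.05858).
for i in range(batch_size):
    for previous_token in set(input_ids[i].tolist()):
        next_token_logits[i, previous_token] /= repetition_penalty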
@@ -635,22 +635,14 @@ class PreTrainedModel(nn.Module):
        if cur_len == max_length:
            input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])

-       if num_return_sequences != 1:
-           input_ids = input_ids.view(batch_size, num_return_sequences, -1)

        return input_ids

    def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
                              repetition_penalty, pad_token_id, eos_token_ids, batch_size, num_return_sequences,
                              length_penalty, num_beams, vocab_size):
-       """ Generate `num_return_sequences` sequences per batch example with beam search.
-           We return the top-`num_return_sequences` beams.
-           `num_return_sequences` should be bigger than `num_beams` (we default to the min of both)
+       """ Generate sequences for each example with beam search.
        """
-       num_return_sequences = min(num_return_sequences, num_beams)

        # Expand input to num beams
        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
@@ -685,7 +677,7 @@ class PreTrainedModel(nn.Module):
                if temperature != 1.0:
                    scores = scores / temperature
                # Top-p/top-k filtering
-               scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)  # (batch_size * num_beams, vocab_size)
+               scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2)  # (batch_size * num_beams, vocab_size)
                # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
                next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
                # Compute next scores
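The new min_tokens_to_keep=2 argument matters because two candidates are drawn per beam; a small sketch with toy scores shows why at least two tokens must survive the filtering:

import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, -10.0, -10.0, -10.0]])  # one beam, toy vocabulary of 4

over_filtered = scores.clone()
over_filtered[0, 1:] = -float("Inf")  # filtering left a single surviving token
# torch.multinomial(F.softmax(over_filtered, dim=-1), num_samples=2)
# -> with replacement=False there is only one non-zero category to draw from,
#    so sampling 2 candidates per beam is no longer possible.

kept_two = scores.clone()
kept_two[0, 2:] = -float("Inf")  # min_tokens_to_keep=2 guarantees at least this much
next_words = torch.multinomial(F.softmax(kept_two, dim=-1), num_samples=2)
print(next_words.shape)  # torch.Size([1, 2])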
@@ -778,41 +770,35 @@ class PreTrainedModel(nn.Module):
            # print("")

        # select the best hypotheses
-       tgt_len = input_ids.new(batch_size, num_return_sequences)
-       bests = []
+       tgt_len = input_ids.new(batch_size)
+       best = []

        for i, hypotheses in enumerate(generated_hyps):
-           best_hyps = [hyp[1] for hyp in sorted(hypotheses.hyp, key=lambda hyp: hyp[0])[-num_return_sequences:]]
-           for j, hyp in enumerate(best_hyps):
-               tgt_len[i, j] = len(hyp) + 1  # +1 for the <EOS> symbol
-           bests.append(best_hyps)
+           best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
+           tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
+           best.append(best_hyp)

        # generate target batch
-       decoded = input_ids.new(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
-       for i, hyps in enumerate(bests):
-           for j, hypo in enumerate(hyps):
-               decoded[i, j, :tgt_len[i, j] - 1] = hypo
-               decoded[i, j, tgt_len[i, j] - 1] = eos_token_ids[0]
-       if num_return_sequences == 1:
-           decoded = decoded.squeeze(1)
-       # # sanity check
-       # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size
+       decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
+       for i, hypo in enumerate(best):
+           decoded[i, :tgt_len[i] - 1] = hypo
+           decoded[i, tgt_len[i] - 1] = eos_token_ids[0]

        return decoded
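A stripped-down sketch of the new selection logic (one best-scoring hypothesis per example, padded into a rectangular batch; the hypothesis lists and token ids here are hypothetical):

import torch

pad_token_id, eos_token_id = 0, 2
# One (score, tokens) list per batch example, as the beam hypotheses store them.
generated_hyps = [
    [(-1.2, torch.tensor([5, 6, 7])), (-0.4, torch.tensor([5, 9]))],
    [(-0.8, torch.tensor([3, 4, 8, 1]))],
]
batch_size = len(generated_hyps)

tgt_len = torch.zeros(batch_size, dtype=torch.long)
best = []
for i, hypotheses in enumerate(generated_hyps):
    best_hyp = max(hypotheses, key=lambda x: x[0])[1]  # highest-scoring hypothesis
    tgt_len[i] = len(best_hyp) + 1                     # +1 for the <EOS> symbol
    best.append(best_hyp)

# Pad everything to the longest hypothesis and terminate each row with EOS.
decoded = torch.full((batch_size, int(tgt_len.max())), pad_token_id, dtype=torch.long)
for i, hypo in enumerate(best):
    decoded[i, :tgt_len[i] - 1] = hypo
    decoded[i, tgt_len[i] - 1] = eos_token_id
print(decoded)
# tensor([[5, 9, 2, 0, 0],
#         [3, 4, 8, 1, 2]])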
-def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
+def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf'), min_tokens_to_keep=1):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
-           logits: logits distribution shape (batch size x vocabulary size)
+           logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+           Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
-   top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
+       top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
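On its own, the top-k branch with the new safety clamp behaves like this toy example:

import torch

logits = torch.tensor([[0.1, 2.0, -1.0, 0.5, 1.5]])
top_k, min_tokens_to_keep, filter_value = 1, 2, -float("Inf")

# Clamp top_k into [min_tokens_to_keep, vocab_size] before filtering.
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # -> 2

# Mask every logit smaller than the k-th largest one.
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
print(logits)  # tensor([[-inf, 2.0, -inf, -inf, 1.5]])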
@@ -821,8 +807,11 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

-       # Remove tokens with cumulative probability above the threshold
+       # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
+       if min_tokens_to_keep > 1:
+           # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+           sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
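Putting the helper to use, a minimal sampling sketch (it assumes a transformers checkout that already contains this commit, so the min_tokens_to_keep keyword exists; drop that argument on older versions):

import torch
import torch.nn.functional as F

from transformers.modeling_utils import top_k_top_p_filtering

torch.manual_seed(0)
logits = torch.randn(2, 50)  # (batch size, vocabulary size)

# Keep the 10 most likely tokens, then apply nucleus filtering at p=0.9, but
# never mask the distribution down below 2 candidates per row.
filtered = top_k_top_p_filtering(logits.clone(), top_k=10, top_p=0.9, min_tokens_to_keep=2)

# Filtered positions are -inf, so softmax assigns them zero probability and
# multinomial can still draw 2 samples per row without replacement.
next_tokens = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=2)
print(next_tokens.shape)  # torch.Size([2, 2])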