chenpangpang/transformers, commit bbc0c86f
Authored Dec 17, 2019 by thomwolf (parent b6938916)

    beam search + single beam decoding

Showing 1 changed file, transformers/modeling_utils.py, with 123 additions and 29 deletions.
@@ -544,29 +544,90 @@ class PreTrainedModel(nn.Module):
         if isinstance(eos_token_ids, int):
             eos_token_ids = [eos_token_ids]
-        assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictly positive integer."
+        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
-        assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictly positive integer."
+        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
-        assert 0 < temperature, "`temperature` should be positive."
+        assert temperature > 0, "`temperature` should be positive."
-        assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictly positive integer."
+        assert isinstance(top_k, int) and top_k > 0, "`top_k` should be a strictly positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
-        assert 0 < repetition_penalty, "`repetition_penalty` should be strictly positive."
+        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
-        assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a positive integer."
+        assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
-        assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a positive integer."
+        assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
-        assert isinstance(eos_token_ids, (list, tuple)) and (0 <= e for e in eos_token_ids), \
-            "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
+        assert isinstance(eos_token_ids, (list, tuple)) and all(e >= 0 for e in eos_token_ids), \
+            "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
-        assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictly positive integer."
+        assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictly positive integer."
-        assert 0 < length_penalty, "`length_penalty` should be strictly positive."
+        assert length_penalty > 0, "`length_penalty` should be strictly positive."

         if input_ids is None:
             input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
         else:
-            assert input_ids.dim() == 2
+            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

         # current position and vocab size
-        cur_len = 1
+        cur_len = input_ids.shape[1]
         vocab_size = self.config.vocab_size

+        if num_beams > 1:
+            return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, length_penalty, num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size)
+        return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, batch_size)
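The two return statements above are the new dispatch point: a single beam falls through to greedy or sampled decoding, several beams go to beam search. As a rough usage sketch (not part of this commit: the checkpoint and the exact public signature of `generate` at this point in history are assumptions, though the argument names are the ones validated above):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    input_ids = torch.tensor([tokenizer.encode("The quick brown fox")])  # shape (1, seq_len)

    # num_beams == 1: dispatches to _generate_no_beam_search (sampling here)
    sampled = model.generate(input_ids, max_length=20, do_sample=True, num_beams=1,
                             temperature=0.9, top_k=50, top_p=0.95)

    # num_beams > 1: dispatches to _generate_beam_search
    beamed = model.generate(input_ids, max_length=20, do_sample=False, num_beams=4)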
+    def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty, pad_token_id, eos_token_ids, batch_size):
+        """ Generate a sentence without beam search (num_beams == 1). """
+        # current position / max lengths / length of generated sentences / unfinished sentences
+        unfinished_sents = input_ids.new(batch_size).fill_(1)
+        # cache compute states
+        pasts = None
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            outputs = self(**model_inputs)
+            next_token_logits = outputs[0][:, -1, :]
+            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                for i in range(batch_size):
+                    for _ in set(input_ids[i].tolist()):
+                        next_token_logits[i, _] /= repetition_penalty
+            if do_sample:
+                # Temperature (higher temperature => more likely to sample low probability tokens)
+                if temperature != 1.0:
+                    next_token_logits = next_token_logits / temperature
+                # Top-p/top-k filtering
+                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+                # Sample
+                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
+            else:
+                # Greedy decoding
+                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
+            # update generations and finished sentences
+            tokens_to_add = next_token * unfinished_sents.unsqueeze(-1) + pad_token_id * (1 - unfinished_sents.unsqueeze(-1))
+            input_ids = torch.cat([input_ids, tokens_to_add], dim=-1)
+            for eos_token_id in eos_token_ids:
+                unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long())
+            cur_len = cur_len + 1
+            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            if unfinished_sents.max() == 0:
+                break
+        # add eos_token_ids to unfinished sentences
+        if cur_len == max_length:
+            input_ids[:, -1].masked_fill_(unfinished_sents.byte(), eos_token_ids[0])
+        return input_ids
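The loop's bookkeeping is easiest to see on toy tensors. A minimal sketch (values invented) of one greedy step, showing the CTRL-style repetition penalty and how `unfinished_sents` forces finished sequences to keep receiving `pad_token_id`:

    import torch

    pad_token_id, eos_token_id = 0, 2
    input_ids = torch.tensor([[1, 3],    # still running
                              [1, 2]])   # already emitted EOS (token 2)
    unfinished_sents = torch.tensor([1, 0])
    next_token_logits = torch.tensor([[0.1, 0.2, 0.3, 2.0],
                                      [0.4, 0.1, 0.2, 0.3]])

    # repetition penalty: damp the logits of tokens already generated
    repetition_penalty = 1.2
    for i in range(input_ids.size(0)):
        for tok in set(input_ids[i].tolist()):
            next_token_logits[i, tok] /= repetition_penalty

    # greedy step, shape (batch_size, 1)
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

    # finished rows get pad_token_id instead of the model's token
    mask = unfinished_sents.unsqueeze(-1)
    tokens_to_add = next_token * mask + pad_token_id * (1 - mask)
    input_ids = torch.cat([input_ids, tokens_to_add], dim=-1)
    unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long())
    print(input_ids)  # tensor([[1, 3, 3], [1, 2, 0]]): the finished row got a pad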
+    def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, length_penalty, num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size):
+        """ Generate a sentence with beam search. """
         # Expand input to num beams
         input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
         input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
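These two reshapes are what make per-beam decoding a plain batched forward pass: every prompt row is duplicated `num_beams` times. A quick shape check with invented values:

    import torch

    batch_size, num_beams, cur_len = 2, 3, 4
    input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

    expanded = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
    flat = expanded.contiguous().view(batch_size * num_beams, cur_len)

    print(flat.shape)              # torch.Size([6, 4]) == (batch_size * num_beams, cur_len)
    print(flat[0].equal(flat[2]))  # True: rows 0..2 are all copies of prompt 0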
@@ -592,9 +653,11 @@ class PreTrainedModel(nn.Module):
             scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
             assert scores.size() == (batch_size * num_beams, vocab_size)

-            # select next words with scores
+            # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
             _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+            # re-organize to group the beams together (we are keeping the top hypotheses across beams)
             _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
             next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
             assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
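The selection above works because each beam's running log probability (`beam_scores`) is added to the per-token log-softmax scores, so ranking by the sum ranks by the log probability of the whole continuation; flattening to `num_beams * vocab_size` then lets a single `topk` compare candidates across beams. A toy run (the `idx % vocab_size` word-id recovery is the assumed counterpart of the `beam_id = idx // vocab_size` line shown further down; it is not visible in this excerpt):

    import torch
    import torch.nn.functional as F

    batch_size, num_beams, vocab_size = 1, 2, 5
    scores = F.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
    beam_scores = torch.tensor([0.0, -1.0])  # running log prob of each live beam

    # sum of logs == log of the product
    _scores = scores + beam_scores[:, None].expand_as(scores)
    _scores = _scores.view(batch_size, num_beams * vocab_size)

    # 2 * num_beams candidates: enough to retire EOS-ending beams as finished
    # hypotheses while still keeping num_beams live continuations
    next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

    beam_id = next_words // vocab_size  # which beam each candidate extends
    word_id = next_words % vocab_size   # which token it appends (assumed, see above)
    print(next_scores.size(), next_words.size())  # both (1, 4) == (batch_size, 2 * num_beams)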
@@ -604,11 +667,11 @@ class PreTrainedModel(nn.Module):
             next_batch_beam = []

             # for each sentence
-            for sent_id in range(batch_size):
+            for batch_ex in range(batch_size):

                 # if we are done with this sentence
-                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
-                if done[sent_id]:
+                done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
+                if done[batch_ex]:
                     next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                     continue
@@ -616,7 +679,7 @@ class PreTrainedModel(nn.Module):
                 next_sent_beam = []

                 # next words for this sentence
-                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
+                for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):

                     # get beam and word IDs
                     beam_id = idx // vocab_size
@@ -624,9 +687,9 @@ class PreTrainedModel(nn.Module):
                     # end of sentence, or next word
                     if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
-                        generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item())
+                        generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item())
                     else:
-                        next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id))
+                        next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))

                 # the beam for next step is full
                 if len(next_sent_beam) == num_beams:
@@ -637,7 +700,7 @@ class PreTrainedModel(nn.Module):
                 if len(next_sent_beam) == 0:
                     next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
                 next_batch_beam.extend(next_sent_beam)
-                assert len(next_batch_beam) == num_beams * (sent_id + 1)
+                assert len(next_batch_beam) == num_beams * (batch_ex + 1)

             # sanity check / prepare next batch
             assert len(next_batch_beam) == batch_size * num_beams
@@ -670,7 +733,7 @@ class PreTrainedModel(nn.Module):
             # print("")

         # select the best hypotheses
-        tgt_len = src_len.new(batch_size)
+        tgt_len = input_ids.new(batch_size)
         best = []

         for i, hypotheses in enumerate(generated_hyps):
@@ -679,15 +742,46 @@ class PreTrainedModel(nn.Module):
             best.append(best_hyp)

         # generate target batch
-        decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_index)
+        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
         for i, hypo in enumerate(best):
-            decoded[:tgt_len[i] - 1, i] = hypo
-            decoded[tgt_len[i] - 1, i] = self.eos_index
+            decoded[i, :tgt_len[i] - 1] = hypo
+            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]

         # sanity check
-        assert (decoded == self.eos_index).sum() == 2 * batch_size
+        # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size

-        return decoded, tgt_len
+        return decoded
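The rewritten epilogue also flips `decoded` from a (length, batch) layout filled via model attributes (`self.pad_index`, `self.eos_index`) to a (batch, length) tensor built from the method's own `pad_token_id` and `eos_token_ids`. A toy rendering of that packing (hypotheses invented):

    import torch

    pad_token_id, eos_token_ids = 0, [2]
    best = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]  # best hypothesis per batch entry
    tgt_len = torch.tensor([4, 3])                          # hypothesis length + 1 for the EOS

    decoded = torch.full((len(best), int(tgt_len.max())), pad_token_id, dtype=torch.long)
    for i, hypo in enumerate(best):
        decoded[i, :tgt_len[i] - 1] = hypo
        decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
    print(decoded)  # tensor([[5, 6, 7, 2], [8, 9, 2, 0]])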
+def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+        Args:
+            logits: logits distribution shape (batch size x vocabulary size)
+            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    top_k = min(top_k, logits.size(-1))  # Safety check
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+    if top_p > 0.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+        logits[indices_to_remove] = filter_value
+    return logits
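A worked call of the `top_p` branch above on a toy distribution (numbers invented). With probabilities (0.5, 0.3, 0.1, 0.1) and `top_p=0.7`, the cumulative sums are (0.5, 0.8, 0.9, 1.0); the right-shift keeps the token that first crosses the threshold, so exactly two tokens survive:

    import torch
    import torch.nn.functional as F

    logits = torch.log(torch.tensor([[0.5, 0.3, 0.1, 0.1]]))  # softmax recovers these probs
    top_p, filter_value = 0.7, -float('Inf')

    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)  # 0.5, 0.8, 0.9, 1.0

    sorted_indices_to_remove = cumulative_probs > top_p  # False, True, True, True
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0                 # False, False, True, True

    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = filter_value
    print(F.softmax(logits, dim=-1))  # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])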
 class BeamHypotheses(object):
...