chenpangpang / transformers

Commit c47394b0, authored Mar 05, 2020 by Patrick von Platen

    refactoring and bug fixing beam search generate

Parent: ff9e79ba

Showing 1 changed file with 48 additions and 30 deletions:
src/transformers/modeling_utils.py (+48, -30)
@@ -15,11 +15,11 @@
 # limitations under the License.
 """PyTorch BERT model."""

 import logging
 import os
 import typing

+import ipdb
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -758,6 +758,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         else:
             assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

+        # not allow to duplicate outputs when greedy decoding
         if do_sample is False:
             if num_beams == 1:
                 # no_beam_search greedy generation conditions
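The comment added in this hunk motivates the guard on greedy decoding: argmax decoding is deterministic, so requesting several return sequences without sampling would only produce duplicates. A toy sketch of that observation (made-up logits, not library code):

import torch

# greedy decoding is deterministic, so asking for several return sequences
# from the same prompt would just duplicate the output
torch.manual_seed(0)
logits = torch.randn(1, 10)                      # one prompt, vocabulary of 10
greedy_picks = [logits.argmax(dim=-1) for _ in range(3)]
assert all(torch.equal(p, greedy_picks[0]) for p in greedy_picks)
print(greedy_picks)                              # three identical token ids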
@@ -781,15 +782,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         cur_len = input_ids.shape[1]
         vocab_size = self.config.vocab_size

-        if num_return_sequences != 1 and do_sample:
-            # Expand input to num return sequences
-            input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
-            input_ids = input_ids.contiguous().view(
-                batch_size * num_return_sequences, cur_len
-            )  # shape: (batch_size * num_return_sequences, cur_len)
+        # set effective batch size and effective batch multiplier according to do_sample
+        if do_sample:
+            effective_batch_size = batch_size * num_return_sequences
+            effective_batch_mult = num_return_sequences
+        else:
+            effective_batch_size = batch_size
+            effective_batch_mult = 1
+
+        # Expand input ids if num_beams > 1 or num_return_sequences > 1
+        if num_return_sequences > 1 or num_beams > 1:
+            input_ids_len = input_ids.shape[-1]
+            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
+            input_ids = input_ids.contiguous().view(
+                effective_batch_size * num_beams, input_ids_len
+            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

         if num_beams > 1:
             output = self._generate_beam_search(
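The new expansion path replicates every prompt effective_batch_mult * num_beams times before decoding. A minimal sketch with assumed sizes (batch_size=2, num_beams=2, effective_batch_mult=2, i.e. sampling with num_return_sequences=2) showing the unsqueeze/expand/view round trip and the resulting row layout:

import torch

batch_size, cur_len = 2, 3
num_beams, effective_batch_mult = 2, 2
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)

# replicate each prompt effective_batch_mult * num_beams times
expanded = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, cur_len)
flat = expanded.contiguous().view(batch_size * effective_batch_mult * num_beams, cur_len)

print(flat.shape)   # torch.Size([8, 3])
print(flat)         # rows 0-3 are copies of prompt 0, rows 4-7 copies of prompt 1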
@@ -892,12 +899,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 # unfinished_sents is set to zero if eos in sentence
                 unfinished_sents.mul_((~eos_in_sents).long())

-            cur_len = cur_len + 1
-
             # stop when there is a </s> in each sentence, or if we exceed the maximul length
             if unfinished_sents.max() == 0:
                 break

+            cur_len = cur_len + 1
+
         # if there are different sentences lengths in the batch, some batches have to be padded
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
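The relocated cur_len increment lives in the greedy/sampling loop that tracks which sentences are still open. A self-contained sketch of the unfinished_sents bookkeeping used in this hunk (toy token ids, assumed eos_token_id=2; not library code):

import torch

eos_token_id = 2
unfinished_sents = torch.ones(3, dtype=torch.long)       # 3 sentences, all still open
for step, tokens_to_add in enumerate([torch.tensor([5, 2, 7]),
                                      torch.tensor([2, 1, 2])]):
    eos_in_sents = tokens_to_add == eos_token_id
    unfinished_sents.mul_((~eos_in_sents).long())         # zero out finished sentences
    if unfinished_sents.max() == 0:                       # stop when every sentence saw </s>
        print("all sentences finished at step", step)
        break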
@@ -932,10 +939,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         """ Generate sequences for each example with beam search.
         """
-        # Expand input to num beams
-        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
-        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
-
         # generated hypotheses
         generated_hyps = [
             BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
         ]
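generated_hyps holds one BeamHypotheses container per batch element. A simplified stand-in (a sketch, not the library class; the name SimpleBeamHypotheses is hypothetical) that mirrors the behaviour implied here: keep the num_beams best finished hypotheses, scored by summed log-probability normalised with the length penalty:

# minimal stand-in for the BeamHypotheses container referenced above
class SimpleBeamHypotheses:
    def __init__(self, num_beams, length_penalty):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.beams = []  # list of (score, token_list)

    def add(self, hyp, sum_logprobs):
        # length-penalised score, as the length_penalty argument suggests
        score = sum_logprobs / len(hyp) ** self.length_penalty
        self.beams.append((score, hyp))
        self.beams.sort(key=lambda x: x[0], reverse=True)
        self.beams = self.beams[: self.num_beams]  # keep only the best num_beams

hyps = SimpleBeamHypotheses(num_beams=2, length_penalty=1.0)
hyps.add([5, 7, 2], sum_logprobs=-3.0)
hyps.add([5, 9, 9, 2], sum_logprobs=-2.0)
hyps.add([5, 4, 2], sum_logprobs=-9.0)
print(hyps.beams)   # the two best hypotheses survive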
@@ -945,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        # if do_sample is False:
-        beam_scores[:, 1:] = -1e9
+        # Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
+        if do_sample is False:
+            beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

         # cache compute states
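The fix gates the -1e9 mask on do_sample is False: at the first step every beam holds the same prompt, so without the mask a greedy topk over identical rows would keep the same token num_beams times, while under sampling the mask would wrongly freeze every beam except the first. A toy sketch of the greedy case (assumed sizes, not library code):

import torch

batch_size, num_beams, vocab_size = 1, 3, 8
do_sample = False

beam_scores = torch.zeros((batch_size, num_beams))
if do_sample is False:
    # all beams start from the same prompt, so force topk to draw its
    # candidates from beam 0 only at the first step
    beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)

token_logprobs = torch.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
next_scores = (token_logprobs + beam_scores[:, None]).view(batch_size, num_beams * vocab_size)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1)
print(next_tokens // vocab_size)   # all zeros: every surviving candidate extends beam 0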
@@ -996,6 +999,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 # Compute next scores
                 next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)

+                # sort the sampled vector to make sure that the first num_beams samples are the best
+                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)
+
             else:
                 # do greedy beam search
                 scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
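The three added lines fix the sampling branch: torch.multinomial returns candidates in draw order, not score order, while the downstream code assumes the first num_beams entries per batch are the best. A sketch of the sample/gather/sort pattern with assumed shapes (here _scores is already flattened to (batch_size, num_beams * vocab_size); a simplification of the surrounding code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_beams, vocab_size = 2, 2, 6
_scores = F.log_softmax(torch.randn(batch_size, num_beams * vocab_size), dim=-1)

# sample 2 * num_beams candidates per batch; multinomial gives them in draw order
next_tokens = torch.multinomial(_scores.exp(), num_samples=2 * num_beams)
next_scores = torch.gather(_scores, -1, next_tokens)             # (batch_size, num_beams * 2)

# sort so that the first num_beams candidates really are the highest-scoring ones
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices)
print(next_scores)   # now monotonically decreasing along dim 1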
@@ -1006,6 +1012,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_scores = next_scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)

                 next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+
+            assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
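Flattening to (batch_size, num_beams * vocab_size) lets a single topk rank every (beam, token) continuation of a batch element, and the added assert pins down the expected shape. A small sketch (toy sizes, not library code) of that step and of how beam and token indices are recovered afterwards:

import torch

batch_size, num_beams, vocab_size = 2, 2, 5
next_scores = torch.randn(batch_size * num_beams, vocab_size)

# flatten beams so one topk call ranks every (beam, token) continuation per batch
next_scores = next_scores.view(batch_size, num_beams * vocab_size)
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

beam_id = next_tokens // vocab_size    # which beam the candidate extends
token_id = next_tokens % vocab_size    # which vocabulary token it adds
print(beam_id, token_id)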
@@ -1041,14 +1048,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     beam_id = idx // vocab_size
                     token_id = idx % vocab_size

-                    # add to generated hypotheses if end of sentence or last iteration
+                    effective_beam_id = batch_idx * num_beams + beam_id
+                    # add to generated hypotheses if end of sentence
                     if eos_token_ids is not None and token_id.item() in eos_token_ids:
                         generated_hyps[batch_idx].add(
-                            input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
+                            input_ids[effective_beam_id].clone(), score.item(),
                         )
                     else:
                         # add next predicted word if it is not eos_token
-                        next_sent_beam.append((score, token_id, batch_idx * num_beams + beam_id))
+                        next_sent_beam.append((score, token_id, effective_beam_id))

                     # the beam for next step is full
                     if len(next_sent_beam) == num_beams:
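The effective_beam_id introduced here is simply the row index of a beam in the flattened (batch_size * num_beams, ...) tensors: beams of the same batch element occupy consecutive rows. A brief illustration of the mapping:

# row layout of the flattened tensors: beam `beam_id` of batch `batch_idx`
# lives at row batch_idx * num_beams + beam_id
num_beams = 3
for batch_idx in range(2):
    for beam_id in range(num_beams):
        effective_beam_id = batch_idx * num_beams + beam_id
        print(f"batch {batch_idx}, beam {beam_id} -> row {effective_beam_id}")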
@@ -1073,24 +1081,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if past:
                 past = self._reorder_cache(past, beam_idx)

-            # update current length
-            cur_len = cur_len + 1
-
             # stop when we are done with each sentence
             if all(done):
                 break

+            # update current length
+            cur_len = cur_len + 1
+
+        # finalize all open beam hypotheses and end to generated hypotheses
         for batch_idx in range(batch_size):
-            # Add all open beam hypothesis to generated_hyps
-            if not done[batch_idx]:
-                for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
-
-                    # get beam and word IDs
-                    beam_id = idx // vocab_size
-                    token_id = idx % vocab_size
-                    generated_hyps[batch_idx].add(
-                        input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
-                    )
+            if done[batch_idx]:
+                continue
+
+            # test that beam scores match previously calculated scores if not eos and batch_idx not done
+            if eos_token_ids is not None and all(
+                (token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
+            ):
+                assert torch.all(
+                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
+                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
+                    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx]
+                )
+
+            # need to add best num_beams hypotheses to generated hyps
+            for beam_id in range(num_beams):
+                effective_beam_id = batch_idx * num_beams + beam_id
+                final_score = beam_scores[effective_beam_id].item()
+                final_tokens = input_ids[effective_beam_id]
+                generated_hyps[batch_idx].add(final_tokens, final_score)

         # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
         output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
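The rewritten finalization no longer rebuilds hypotheses from next_tokens; for every batch element that is not done it flushes the current top num_beams running sequences and their accumulated scores into generated_hyps. A self-contained toy sketch of that loop (plain lists stand in for BeamHypotheses; all values are made up):

import torch

batch_size, num_beams, cur_len = 2, 2, 4
input_ids = torch.arange(batch_size * num_beams * cur_len).view(batch_size * num_beams, cur_len)
beam_scores = torch.tensor([-1.0, -2.0, -0.5, -3.0])
done = [False, True]
generated_hyps = [[] for _ in range(batch_size)]   # stand-in for BeamHypotheses containers

for batch_idx in range(batch_size):
    if done[batch_idx]:
        continue
    # flush the running beams of this batch element into its hypotheses container
    for beam_id in range(num_beams):
        effective_beam_id = batch_idx * num_beams + beam_id
        final_score = beam_scores[effective_beam_id].item()
        final_tokens = input_ids[effective_beam_id]
        generated_hyps[batch_idx].append((final_score, final_tokens))

print(generated_hyps[0])   # beams of batch 0 flushed; batch 1 was already done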