chenpangpang / transformers / Commits

Commit b6938916
Authored Dec 17, 2019 by thomwolf

    adding beam search

Parent: a468870f
Showing 2 changed files with 235 additions and 41 deletions:

    transformers/configuration_utils.py    +7   -2
    transformers/modeling_utils.py         +228 -39
transformers/configuration_utils.py  (view file @ b6938916)

@@ -62,13 +62,18 @@ class PretrainedConfig(object):

         self.is_decoder = kwargs.pop('is_decoder', False)

         # Parameters for sequence generation
-        self.generate_length = kwargs.pop('generate_length', 10)
+        self.generate_max_length = kwargs.pop('generate_max_length', 20)
         self.generate_do_sample = kwargs.pop('generate_do_sample', False)
         self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
         self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
         self.generate_top_k = kwargs.pop('generate_top_k', 50)
-        self.generate_top_p = kwargs.pop('generate_top_p', 0.0)
+        self.generate_top_p = kwargs.pop('generate_top_p', 1.0)
         self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
+        self.generate_bos_token_id = kwargs.pop('generate_bos_token_id', 0)
+        self.generate_pad_token_id = kwargs.pop('generate_pad_token_id', 0)
+        self.generate_eos_token_ids = kwargs.pop('generate_eos_token_ids', 0)
+        self.generate_batch_size = kwargs.pop('generate_batch_size', 1)
+        self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.)

     def save_pretrained(self, save_directory):
         """ Save a configuration object to the directory `save_directory`, so that it
         ...
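
The pattern above gives every model a full set of generation defaults that can still be overridden per call. A minimal, self-contained sketch of the `kwargs.pop` idiom (toy class and values, not the library code):

    class ToyConfig(object):
        def __init__(self, **kwargs):
            # pop() removes each recognized key, so leftover kwargs can later be
            # detected as unknown; the second argument is the default value.
            self.generate_max_length = kwargs.pop('generate_max_length', 20)
            self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
            self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.)

    config = ToyConfig(generate_num_beams=4)
    assert config.generate_num_beams == 4      # caller-supplied value wins
    assert config.generate_max_length == 20    # missing keys fall back to defaults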
transformers/modeling_utils.py  (view file @ b6938916)

 # coding=utf-8
-# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 ...

@@ -488,63 +488,252 @@ class PreTrainedModel(nn.Module):
         return model

-    def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None,
-                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
-                 **model_kwargs):
-        """ Generic sequence generator for single-stack models with a LM head.
-
-        The method currently supports greedy decoding and sampling. See the
-        documentation of the `Sampler` class for more information about the
-        parameters related to sampling.
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        return {"input_ids": input_ids}
+
+    def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
+                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
+                 bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
+                 length_penalty=None, **kwargs):
+        """ Sequence generator for models with a LM head.
+
+        The method currently supports greedy or penalized greedy decoding, sampling with
+        top-k or nucleus sampling, and beam search.
+
+        Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM

         Params:
             **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
                 The sequence used as a prompt for the generation. If `None` the method initializes
                 it as an empty `torch.LongTensor` of shape (1,)
-            **length**: (`optional`) int
-                The length of the sequence to be generated.
+            **max_length**: (`optional`) int
+                The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
             **do_sample**: (`optional`) bool
-                If set to `False` we use greedy decoding; otherwise sampling.
+                If set to `False` we use greedy decoding; otherwise sampling. Default to greedy decoding.
+            **num_beams**: (`optional`) int
+                Number of beams for beam search. 1 means no beam search. Default to 1.
             **temperature**: (`optional`) float
                 The value used to modulate the next token probabilities.
-            **k**: (`optional`) int
-                The parameter used for k-filtering.
-            **p**: (`optional`) float
-                The parameter for nucleus sampling. Must be between 0 and 1.
+            **top_k**: (`optional`) int
+                The number of highest probability vocabulary tokens to keep for top-k filtering.
+                Between 1 and infinity. Default to 50.
+            **top_p**: (`optional`) float
+                The cumulative probability of the highest probability vocabulary tokens to keep
+                for nucleus sampling. Must be between 0 and 1. Default to 1.
             **repetition_penalty**: (`optional`) float
-                The parameter for repetition penalty.
+                The parameter for repetition penalty. Between 1.0 and +infinity. 1.0 means no penalty. Default to 1.
         """
-        if input_ids is None:
-            input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device)
-
         # We cannot generate if the model does not have a LM head
         if self.get_output_embeddings() is None:
             raise AttributeError("You tried to generate sequences with a model that does not have a LM Head.")

-        sampler_config = {
-            "k": k,
-            "p": p,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-        }
-        sampler = Sampler(**sampler_config)
-        generated_sequence = input_ids
-        for _ in trange(length):
-            arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
-            outputs = self(**arguments)
-            next_tokens_logits = outputs[0][:, -1, :]
-            next_tokens = sampler.get_one_token(
-                next_tokens_logits, generated_sequence
-            )
-            generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
-
-        return generated_sequence.squeeze(0)
-
-    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
-        model_kwargs.update({"input_ids": input_ids})
-        return model_kwargs
+        max_length = max_length if max_length is not None else self.config.generate_max_length
+        do_sample = do_sample if do_sample is not None else self.config.generate_do_sample
+        num_beams = num_beams if num_beams is not None else self.config.generate_num_beams
+        temperature = temperature if temperature is not None else self.config.generate_temperature
+        top_k = top_k if top_k is not None else self.config.generate_top_k
+        top_p = top_p if top_p is not None else self.config.generate_top_p
+        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generate_bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generate_pad_token_id
+        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.generate_eos_token_ids
+        batch_size = batch_size if batch_size is not None else self.config.generate_batch_size
+        length_penalty = length_penalty if length_penalty is not None else self.config.generate_length_penalty
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]  # overriden by the input batch_size
+        if isinstance(eos_token_ids, int):
+            eos_token_ids = [eos_token_ids]
+
+        assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictly positive integer."
+        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
+        assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictly positive integer."
+        assert 0 < temperature, "`temperature` should be positive."
+        assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictly positive integer."
+        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
+        assert 0 < repetition_penalty, "`repetition_penalty` should be strictly positive."
+        assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a positive integer."
+        assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a positive integer."
+        assert isinstance(eos_token_ids, (list, tuple)) and all(0 <= e for e in eos_token_ids), \
+            "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
+        assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictly positive integer."
+        assert 0 < length_penalty, "`length_penalty` should be strictly positive."
+
+        if input_ids is None:
+            input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long,
+                                   device=next(self.parameters()).device)
+        else:
+            assert input_ids.dim() == 2
+
+        # current position and vocab size
+        cur_len = 1
+        vocab_size = self.config.vocab_size
+
+        # Expand input to num beams
+        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
+        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
+
+        # generated hypotheses
+        generated_hyps = [
+            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False)
+            for _ in range(batch_size)
+        ]
+
+        # scores for each sentence in the beam
+        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
+        beam_scores[:, 1:] = -1e9
+        beam_scores = beam_scores.view(-1)
+
+        # cache compute states
+        pasts = None  # self.prepare_pasts()
+
+        # done sentences
+        done = [False for _ in range(batch_size)]
+
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            scores = self(**model_inputs)[0]        # (batch_size * num_beams, cur_len, vocab_size)
+            scores = scores[:, -1, :]               # (batch_size * num_beams, vocab_size)
+            scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
+            assert scores.size() == (batch_size * num_beams, vocab_size)
+
+            # select next words with scores
+            _scores = scores + beam_scores[:, None].expand_as(scores)   # (batch_size * num_beams, vocab_size)
+            _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
+
+            next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+            assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
+
+            # next batch beam content
+            # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
+            next_batch_beam = []
+
+            # for each sentence
+            for sent_id in range(batch_size):
+
+                # if we are done with this sentence
+                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
+                if done[sent_id]:
+                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
+                    continue
+
+                # next sentence beam content
+                next_sent_beam = []
+
+                # next words for this sentence
+                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
+
+                    # get beam and word IDs
+                    beam_id = idx // vocab_size
+                    word_id = idx % vocab_size
+
+                    # end of sentence, or next word
+                    if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
+                        generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item())
+                    else:
+                        next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id))
+
+                    # the beam for next step is full
+                    if len(next_sent_beam) == num_beams:
+                        break
+
+                # update next beam content
+                assert len(next_sent_beam) == (0 if cur_len + 1 == max_length else num_beams)
+                if len(next_sent_beam) == 0:
+                    next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
+                next_batch_beam.extend(next_sent_beam)
+                assert len(next_batch_beam) == num_beams * (sent_id + 1)
+
+            # sanity check / prepare next batch
+            assert len(next_batch_beam) == batch_size * num_beams
+            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
+            beam_words = input_ids.new([x[1] for x in next_batch_beam])
+            beam_idx = input_ids.new([x[2] for x in next_batch_beam])
+
+            # re-order batch and internal states
+            input_ids = input_ids[beam_idx, :]
+            input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
+
+            # TODO: Activate cache
+            # for k in cache.keys():
+            #     if k != 'slen':
+            #         cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
+
+            # update current length
+            cur_len = cur_len + 1
+
+            # stop when we are done with each sentence
+            if all(done):
+                break
+
+        # visualize hypotheses
+        # print([len(x) for x in generated_hyps], cur_len)
+        # globals().update( locals() );
+        # !import code; code.interact(local=vars())
+        # for ii in range(batch_size):
+        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
+        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
+        #     print("")
+
+        # select the best hypotheses
+        # NOTE: `src_len`, `self.pad_index` and `self.eos_index` below are leftovers from
+        # the XLM code this was adapted from and are not defined in this class.
+        tgt_len = src_len.new(batch_size)
+        best = []
+
+        for i, hypotheses in enumerate(generated_hyps):
+            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
+            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
+            best.append(best_hyp)
+
+        # generate target batch
+        decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_index)
+        for i, hypo in enumerate(best):
+            decoded[:tgt_len[i] - 1, i] = hypo
+            decoded[tgt_len[i] - 1, i] = self.eos_index
+
+        # sanity check
+        assert (decoded == self.eos_index).sum() == 2 * batch_size
+
+        return decoded, tgt_len
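
The loop above keeps every hypothesis flattened into a (batch_size * num_beams, ...) batch and recovers, from a single topk over the num_beams * vocab_size merged candidates per sentence, both the beam a candidate extends (idx // vocab_size) and the token it appends (idx % vocab_size). Taking 2 * num_beams candidates leaves enough continuations even when some candidates hit EOS. A self-contained sketch of that index arithmetic with toy sizes (illustrative, not the library code):

    import torch
    import torch.nn.functional as F

    batch_size, num_beams, vocab_size = 2, 3, 5  # toy sizes

    # One score row per (sentence, beam) pair, as in the flattened layout above.
    scores = F.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
    beam_scores = torch.zeros(batch_size * num_beams)

    _scores = scores + beam_scores[:, None]                     # cumulative hypothesis scores
    _scores = _scores.view(batch_size, num_beams * vocab_size)  # merge beams per sentence

    next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

    beam_id = next_words // vocab_size  # which beam each candidate extends
    word_id = next_words % vocab_size   # which token it appends
    assert int(beam_id.max()) < num_beams and int(word_id.max()) < vocab_size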
+
+class BeamHypotheses(object):
+
+    def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_length = max_length - 1  # ignoring bos_token
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.n_hyp = n_hyp
+        self.hyp = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.hyp)
+
+    def add(self, hyp, sum_logprobs):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / len(hyp) ** self.length_penalty
+        if len(self) < self.n_hyp or score > self.worst_score:
+            self.hyp.append((score, hyp))
+            if len(self) > self.n_hyp:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
+                del self.hyp[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs):
+        """
+        If there are enough hypotheses and none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+        """
+        if len(self) < self.n_hyp:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
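
A quick sketch of how this n-best list behaves (toy values, assuming the class exactly as defined above): once it holds n_hyp hypotheses, a new one is kept only if its length-normalized score, sum_logprobs / len(hyp) ** length_penalty, beats the current worst, which is then evicted.

    import torch

    hyps = BeamHypotheses(n_hyp=2, max_length=10, length_penalty=1.0, early_stopping=False)
    hyps.add(torch.tensor([5, 7, 9]), sum_logprobs=-3.0)  # score -3.0 / 3 = -1.0
    hyps.add(torch.tensor([5, 7]), sum_logprobs=-4.0)     # score -4.0 / 2 = -2.0
    hyps.add(torch.tensor([5, 8, 9]), sum_logprobs=-1.5)  # score -0.5; evicts the -2.0 hypothesis
    assert len(hyps) == 2 and hyps.worst_score == -1.0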

 class Sampler(object):
 ...
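
Taken together, the two files are meant to allow calls like the following (a hypothetical usage sketch against this commit's API; the model and prompt are illustrative, and note that the hypothesis-selection code above still references undefined XLM leftovers, so the beam-search path as committed would fail at that point):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    input_ids = torch.tensor([tokenizer.encode("The quick brown fox")])

    # Per-call arguments override the config-level generate_* defaults added above.
    output = model.generate(input_ids, max_length=20, num_beams=3,
                            length_penalty=1.0, eos_token_ids=[tokenizer.eos_token_id],
                            pad_token_id=0)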