Commit 641a8dec authored Dec 18, 2019 by thomwolf

clean up code and add arbitrary number of return sequences

parent 77d39720
Showing 4 changed files with 120 additions and 378 deletions
transformers/configuration_utils.py          +13   -12
transformers/modeling_encoder_decoder.py      +0   -95
transformers/modeling_utils.py              +107   -58
transformers/tests/sampling_test.py           +0  -213
transformers/configuration_utils.py
...
...
@@ -62,18 +62,19 @@ class PretrainedConfig(object):
         self.is_decoder = kwargs.pop('is_decoder', False)

         # Parameters for sequence generation
-        self.generate_max_length = kwargs.pop('generate_max_length', 20)
-        self.generate_do_sample = kwargs.pop('generate_do_sample', False)
-        self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
-        self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
-        self.generate_top_k = kwargs.pop('generate_top_k', 50)
-        self.generate_top_p = kwargs.pop('generate_top_p', 1.0)
-        self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
-        self.generate_bos_token_id = kwargs.pop('generate_bos_token_id', 0)
-        self.generate_pad_token_id = kwargs.pop('generate_pad_token_id', 0)
-        self.generate_eos_token_ids = kwargs.pop('generate_eos_token_ids', 0)
-        self.generate_batch_size = kwargs.pop('generate_batch_size', 1)
-        self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.)
+        self.max_length = kwargs.pop('max_length', 20)
+        self.do_sample = kwargs.pop('do_sample', False)
+        self.num_beams = kwargs.pop('num_beams', 1)
+        self.temperature = kwargs.pop('temperature', 1.0)
+        self.top_k = kwargs.pop('top_k', 50)
+        self.top_p = kwargs.pop('top_p', 1.0)
+        self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
+        self.bos_token_id = kwargs.pop('bos_token_id', 0)
+        self.pad_token_id = kwargs.pop('pad_token_id', 0)
+        self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
+        self.batch_size = kwargs.pop('batch_size', 1)
+        self.length_penalty = kwargs.pop('length_penalty', 1.)
+        self.num_return_sequences = kwargs.pop('num_return_sequences', 1)

     def save_pretrained(self, save_directory):
         """ Save a configuration object to the directory `save_directory`, so that it
...
...
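A minimal sketch, not part of this diff, of the `kwargs.pop` pattern used in the hunk above; `ToyConfig` is a hypothetical stand-in for PretrainedConfig, with attribute names matching the new generation defaults:

    class ToyConfig(object):
        def __init__(self, **kwargs):
            # any keyword the caller passes overrides the hard-coded default
            self.max_length = kwargs.pop('max_length', 20)
            self.num_return_sequences = kwargs.pop('num_return_sequences', 1)

    config = ToyConfig(num_return_sequences=3)
    assert config.max_length == 20            # default kept
    assert config.num_return_sequences == 3   # caller override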
transformers/modeling_encoder_decoder.py
...
...
@@ -25,7 +25,6 @@ from torch import nn
 from tqdm import trange

 from .modeling_auto import AutoModel, AutoModelWithLMHead
 from .modeling_utils import Sampler

 logger = logging.getLogger(__name__)
...
...
@@ -203,100 +202,6 @@ class PreTrainedEncoderDecoder(nn.Module):
         return decoder_outputs + encoder_outputs

-    def decode(self, encoder_input_ids, decoder_prompt_ids=None, device=torch.device("cpu"),
-               length=10, do_sample=False, temperature=1.0, k=9, p=0., repetition_penalty=1., **kwargs):
-        """ Generic sequence generator for encoder-decoder models.
-
-        For encoder-decoders the generation consists in:
-            - Performing a forward pass through the encoder once;
-            - Pass the encoder's hidden states to a decoding mechanism that
-              repeatedly calls the decoder to generate sequences.
-
-        The method currently supports greedy decoding and sampling. See the
-        documentation of the `Sampler` class for more information about the
-        parameters related to sampling.
-
-        Params:
-            **encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length)
-                The sequence to encode.
-            **decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
-                The sequence used as a prompt for the generation. If `None` the method initializes
-                it as an empty `torch.LongTensor` of shape (1,)
-            **device**: (`optional`) `torch.device`
-                The device on which the prompt_ids will be initialized if not provided.
-            **length**: (`optional`) int
-                The length of the sequence to be generated.
-            **do_sample**: (`optional`) bool
-                If set to `False` we use greedy decoding; otherwise sampling.
-            **temperature**: (`optional`) float
-                The value used to module the next token probabilities.
-            **k**: (`optional`) int
-                The parameter used for k-filtering.
-            **p**: (`optional`) float
-                The parameter for nucleus sampling. Must be between 0 and 1.
-            **repetition_penalty**: (`optional`) float
-                The parameter for repetition penalty.
-        """
-        if decoder_prompt_ids is None:
-            decoder_prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)
-
-        # When the model does not have a LM head `get_output_embeddings`
-        # returns `None`. We use this mechanism to determine whether we
-        # should proceed with decoding or not.
-        if self.decoder.get_output_embeddings() is None:
-            raise AttributeError("You tried do generated sequences with a decoder that does not have a LM Head.")
-
-        # The followings checks that the decoder is on the same device as the one
-        # that is specified. It only works for models that fit on one GPU.
-        decoder_device = next(self.decoder.parameters()).device
-        if decoder_device != decoder_prompt_ids.device:
-            warnings.warn("The decoder is not on the same device as the prompt. Expected {}, got {}.".format(
-                decoder_prompt_ids.device, decoder_device))
-
-        kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
-        with torch.no_grad():
-            encoder_outputs = self.encoder(encoder_input_ids, **kwargs)
-            encoder_hidden_states = encoder_outputs[0]
-        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
-
-        sampler_config = {
-            "k": k,
-            "p": p,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "repetition_penalty": repetition_penalty,
-        }
-        return self._greedy_decode_or_sample(decoder_prompt_ids, length, sampler_config, **kwargs_decoder)
-
-    def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **kwargs_decoder):
-        sampler = Sampler(**sampler_config)
-        with torch.no_grad():
-            generated_sequence = prompt_ids
-            for _ in trange(length):
-                arguments = self.decoder._prepare_inputs_for_decoding(generated_sequence, **kwargs_decoder)
-                outputs = self.decoder(**arguments)
-                next_tokens_logits = outputs[0][:, -1, :]
-                next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence)
-                generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
-
-        return generated_sequence.squeeze(0)
-
     @staticmethod
     def prepare_model_kwargs(**kwargs):
         """ Prepare the encoder and decoder's keyword arguments.
...
...
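A minimal sketch, assuming PyTorch, of the autoregressive loop the removed `_greedy_decode_or_sample` implemented and that `PreTrainedModel.generate` now covers; `toy_logits` is an illustrative stand-in for a decoder forward pass:

    import torch

    vocab_size, length = 10, 5
    generated = torch.tensor([[1, 2, 3]], dtype=torch.long)
    for _ in range(length):
        toy_logits = torch.randn(1, generated.shape[1], vocab_size)  # stand-in for self.decoder(...)
        next_token = torch.argmax(toy_logits[:, -1, :], dim=-1, keepdim=True)  # greedy pick at the last position
        generated = torch.cat((generated, next_token), dim=1)
    assert generated.shape == (1, 3 + length)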
transformers/modeling_utils.py
...
...
@@ -494,7 +494,7 @@ class PreTrainedModel(nn.Module):
     def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None, temperature=None,
                  top_k=None, top_p=None, repetition_penalty=None, bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
-                 length_penalty=None, **kwargs):
+                 length_penalty=None, num_return_sequences=None, **kwargs):
         """ Sequence generator for models with a LM head.

         The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
...
...
@@ -526,18 +526,19 @@ class PreTrainedModel(nn.Module):
         if self.get_output_embeddings() is None:
             raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")

-        max_length = max_length if max_length is not None else self.config.generate_max_length
-        do_sample = do_sample if do_sample is not None else self.config.generate_do_sample
-        num_beams = num_beams if num_beams is not None else self.config.generate_num_beams
-        temperature = temperature if temperature is not None else self.config.generate_temperature
-        top_k = top_k if top_k is not None else self.config.generate_top_k
-        top_p = top_p if top_p is not None else self.config.generate_top_p
-        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generate_bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generate_pad_token_id
-        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.generate_eos_token_ids
-        batch_size = batch_size if batch_size is not None else self.config.generate_batch_size
-        length_penalty = length_penalty if length_penalty is not None else self.config.generate_length_penalty
+        max_length = max_length if max_length is not None else self.config.max_length
+        do_sample = do_sample if do_sample is not None else self.config.do_sample
+        num_beams = num_beams if num_beams is not None else self.config.num_beams
+        temperature = temperature if temperature is not None else self.config.temperature
+        top_k = top_k if top_k is not None else self.config.top_k
+        top_p = top_p if top_p is not None else self.config.top_p
+        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
+        batch_size = batch_size if batch_size is not None else self.config.batch_size
+        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        num_return_sequences = num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences

         if input_ids is not None:
             batch_size = input_ids.shape[0]  # overriden by the input batch_size
...
...
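A minimal sketch of the argument resolution above, with illustrative names: a value passed to `generate()` wins, otherwise the value stored on the config is used.

    def resolve(explicit_value, config_value):
        # mirrors the `x = x if x is not None else self.config.x` lines above
        return explicit_value if explicit_value is not None else config_value

    class ToyGenerationDefaults(object):
        num_return_sequences = 1

    assert resolve(None, ToyGenerationDefaults.num_return_sequences) == 1  # fall back to the config
    assert resolve(5, ToyGenerationDefaults.num_return_sequences) == 5     # explicit argument wins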
@@ -547,8 +548,8 @@ class PreTrainedModel(nn.Module):
         assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
         assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
-        assert temperature > 0, "`temperature` should be positive."
-        assert isinstance(top_k, int) and top_k > 0, "`top_k` should be a strictely positive integer."
+        assert temperature > 0, "`temperature` should be strictely positive."
+        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
         assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
         assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
...
...
@@ -557,30 +558,41 @@ class PreTrainedModel(nn.Module):
             "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
         assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictely positive integer."
         assert length_penalty > 0, "`length_penalty` should be strictely positive."
+        assert isinstance(num_return_sequences, int) and num_return_sequences > 0, "`num_return_sequences` should be a strictely positive integer."

         if input_ids is None:
             input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
         else:
-            assert input_ids.dims() == 2, "Input prompt should be of shape (batch_size, sequence length)."
+            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

         # current position and vocab size
         cur_len = input_ids.shape[1]
         vocab_size = self.config.vocab_size

         if num_beams > 1:
-            return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, length_penalty,
-                                              num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size)
+            return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
+                                              repetition_penalty, pad_token_id, eos_token_ids, batch_size,
+                                              num_return_sequences, length_penalty, num_beams, vocab_size)

         return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty,
-                                             pad_token_id, eos_token_ids, batch_size)
+                                             pad_token_id, eos_token_ids, batch_size, num_return_sequences)

     def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p, repetition_penalty,
-                                 pad_token_id, eos_token_ids, batch_size):
-        """ Generate a sentence without beam search (num_beams == 1). """
+                                 pad_token_id, eos_token_ids, batch_size, num_return_sequences):
+        """ Generate `num_return_sequences` sequences per batch example without beam search (num_beams == 1).
+            All returned sequence are generated independantly.
+        """
+        # Expand input to num return sequences
+        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
+        input_ids = input_ids.contiguous().view(batch_size * num_return_sequences, cur_len)  # (batch_size*num_return_sequences, cur_len)

         # current position / max lengths / length of generated sentences / unfinished sentences
-        unfinished_sents = input_ids.new(batch_size).fill_(1)
+        unfinished_sents = input_ids.new(batch_size * num_return_sequences).fill_(1)

         # cache compute states
         pasts = None
...
...
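A minimal sketch, assuming PyTorch, of the expansion added above: each prompt row is repeated `num_return_sequences` times so the independent samples can be generated in a single batched pass.

    import torch

    batch_size, num_return_sequences, cur_len = 2, 3, 4
    input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)
    expanded = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
    expanded = expanded.contiguous().view(batch_size * num_return_sequences, cur_len)
    assert expanded.shape == (batch_size * num_return_sequences, cur_len)
    assert torch.equal(expanded[0], expanded[2])  # rows 0-2 are all copies of the first prompt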
@@ -592,9 +604,9 @@ class PreTrainedModel(nn.Module):
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                for i in range(batch_size):
-                    for _ in set(input_ids[i].tolist()):
-                        next_token_logits[i, _] /= repetition_penalty
+                for i in range(batch_size * num_return_sequences):
+                    for previous_tokens in set(input_ids[i].tolist()):
+                        next_token_logits[i, previous_tokens] /= repetition_penalty

             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
...
...
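A minimal sketch, assuming PyTorch, of the CTRL-style penalty above: the logit of every token that already appears in a sequence is divided by `repetition_penalty`, which discourages repeating it (for positive logits).

    import torch

    next_token_logits = torch.tensor([[2.0, 1.0, 0.5]])
    generated_so_far = torch.tensor([[0, 2]])  # token ids 0 and 2 were already generated
    repetition_penalty = 2.0
    for i in range(generated_so_far.shape[0]):
        for previous_token in set(generated_so_far[i].tolist()):
            next_token_logits[i, previous_token] /= repetition_penalty
    assert next_token_logits.tolist() == [[1.0, 1.0, 0.25]]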
@@ -603,16 +615,16 @@ class PreTrainedModel(nn.Module):
                 # Top-p/top-k filtering
                 next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                 # Sample
-                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
+                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
             else:
                 # Greedy decoding
-                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
+                next_token = torch.argmax(next_token_logits, dim=-1)

             # update generations and finished sentences
             tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
-            input_ids = torch.cat([input_ids, tokens_to_add], dim=-1)
+            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
             for eos_token_id in eos_token_ids:
-                unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long())
+                unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long())
             cur_len = cur_len + 1

             # stop when there is a </s> in each sentence, or if we exceed the maximul length
...
...
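A minimal sketch, assuming PyTorch, of the two branches above: with `do_sample` the next token is drawn from the softmax of the filtered logits, otherwise greedy decoding takes the argmax; after this change both branches return a tensor of shape (batch,).

    import torch
    import torch.nn.functional as F

    next_token_logits = torch.tensor([[0.1, 2.0, -1.0]])
    sampled = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
    greedy = torch.argmax(next_token_logits, dim=-1)
    assert greedy.item() == 1                     # index of the largest logit
    assert sampled.shape == greedy.shape == (1,)  # one token per sequence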
@@ -621,13 +633,24 @@ class PreTrainedModel(nn.Module):
         # add eos_token_ids to unfinished sentences
         if cur_len == max_length:
-            input_ids[:, -1].masked_fill_(unfinished_sents.byte(), eos_token_ids[0])
+            input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])

+        if num_return_sequences != 1:
+            input_ids = input_ids.view(batch_size, num_return_sequences, -1)

         return input_ids

-    def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, length_penalty,
-                              num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size):
-        """ Generate a sentence with beam search. """
+    def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, temperature, top_k, top_p,
+                              repetition_penalty, pad_token_id, eos_token_ids, batch_size, num_return_sequences,
+                              length_penalty, num_beams, vocab_size):
+        """ Generate `num_return_sequences` sequences per batch example with beam search.
+            We return the top-`num_return_sequences` beams.
+            `num_return_sequences` should be bigger than `num_beams` (we default to the min of both)
+        """
+        num_return_sequences = min(num_return_sequences, num_beams)

         # Expand input to num beams
         input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
         input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
...
...
@@ -638,7 +661,7 @@ class PreTrainedModel(nn.Module):
         # scores for each sentence in the beam
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores[:, 1:] = -1e9
-        beam_scores = beam_scores.view(-1)
+        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

         # cache compute states
         pasts = None  # self.prepare_pasts()
...
...
@@ -648,18 +671,40 @@ class PreTrainedModel(nn.Module):
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
-            scores = self(**model_inputs)[0]           # (batch_size * num_beams, cur_len, vocab_size)
-            scores = scores[:, -1, :]                  # (batch_size * num_beams, vocab_size)
-            scores = F.log_softmax(scores, dim=-1)     # (batch_size * num_beams, vocab_size)
-            assert scores.size() == (batch_size * num_beams, vocab_size)
-
-            # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
-            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
-
-            # re-organize to group the beam together (we are keeping top hypothesis accross beams)
-            _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
-
-            next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+            scores = self(**model_inputs)[0]           # (batch_size * num_beams, cur_len, vocab_size)
+            scores = scores[:, -1, :]                  # (batch_size * num_beams, vocab_size)
+
+            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                for i in range(batch_size * num_beams):
+                    for previous_tokens in set(input_ids[i].tolist()):
+                        scores[i, previous_tokens] /= repetition_penalty
+
+            if do_sample:
+                # Temperature (higher temperature => more likely to sample low probability tokens)
+                if temperature != 1.0:
+                    scores = scores / temperature
+                # Top-p/top-k filtering
+                scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)  # (batch_size * num_beams, vocab_size)
+                # Sample 2 next words for each beam (so we have some spare tokens and match output of greedy beam search)
+                next_words = torch.multinomial(F.softmax(scores, dim=-1), num_samples=2)  # (batch_size * num_beams, 2)
+                # Compute next scores
+                _scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
+                _scores = torch.gather(_scores, -1, next_words)  # (batch_size * num_beams, 2)
+                next_scores = _scores + beam_scores[:, None].expand_as(_scores)  # (batch_size * num_beams, 2)
+                # Match shape of greedy beam search
+                next_words = next_words.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
+                next_scores = next_scores.view(batch_size, 2 * num_beams)  # (batch_size, 2 * num_beams)
+            else:
+                # do greedy beam search
+                scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
+                assert scores.size() == (batch_size * num_beams, vocab_size)
+                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
+                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
+                # re-organize to group the beam together (we are keeping top hypothesis accross beams)
+                _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
+                next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

             assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)

             # next batch beam content
...
...
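A minimal sketch, assuming PyTorch, of the greedy beam step above: running beam scores are added to the per-token log-probabilities, the result is flattened to (batch, num_beams * vocab) and the top 2 * num_beams candidates are kept, the spares covering hypotheses that end in EOS.

    import torch
    import torch.nn.functional as F

    batch_size, num_beams, vocab_size = 1, 2, 5
    scores = F.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
    beam_scores = torch.zeros(batch_size * num_beams)
    _scores = scores + beam_scores[:, None].expand_as(scores)   # add running beam scores
    _scores = _scores.view(batch_size, num_beams * vocab_size)  # group the beams per batch example
    next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
    assert next_scores.shape == next_words.shape == (batch_size, 2 * num_beams)
    # next_words // vocab_size is the beam index, next_words % vocab_size the token id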
@@ -733,32 +778,36 @@ class PreTrainedModel(nn.Module):
         # print("")

         # select the best hypotheses
-        tgt_len = input_ids.new(batch_size)
-        best = []
+        tgt_len = input_ids.new(batch_size, num_return_sequences)
+        bests = []

         for i, hypotheses in enumerate(generated_hyps):
-            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
-            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
-            best.append(best_hyp)
+            best_hyps = [hyp[1] for hyp in sorted(hypotheses.hyp, key=lambda hyp: hyp[0])[-num_return_sequences:]]
+            for j, hyp in enumerate(best_hyps):
+                tgt_len[i, j] = len(hyp) + 1  # +1 for the <EOS> symbol
+            bests.append(best_hyps)

         # generate target batch
-        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
-        for i, hypo in enumerate(best):
-            decoded[i, :tgt_len[i] - 1] = hypo
-            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
+        decoded = input_ids.new(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
+        for i, hyps in enumerate(bests):
+            for j, hypo in enumerate(hyps):
+                decoded[i, j, :tgt_len[i, j] - 1] = hypo
+                decoded[i, j, tgt_len[i, j] - 1] = eos_token_ids[0]
+
+        if num_return_sequences == 1:
+            decoded = decoded.squeeze(1)

         # # sanity check
         # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size

         return decoded


-def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
     """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
         Args:
             logits: logits distribution shape (batch size x vocabulary size)
-            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
         Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
         From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
...
...
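A minimal sketch, with illustrative data, of the hypothesis selection above: each stored hypothesis is a (score, token_ids) pair, sorted() orders them by ascending score, and the slice keeps the `num_return_sequences` best ones.

    hyp = [(-1.5, [7, 8]), (-0.2, [7, 9, 4]), (-3.0, [7])]  # (score, token_ids), illustrative values
    num_return_sequences = 2
    best_hyps = [h[1] for h in sorted(hyp, key=lambda h: h[0])[-num_return_sequences:]]
    assert best_hyps == [[7, 8], [7, 9, 4]]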
@@ -768,7 +817,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value

-    if top_p > 0.0:
+    if top_p < 1.0:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
...
...
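A minimal self-contained sketch, assuming PyTorch and not the library function itself, of the nucleus cut-off guarded by the new `if top_p < 1.0:` branch: sort the logits, accumulate probabilities, and mask every token outside the smallest set whose cumulative probability exceeds `top_p`, always keeping the most probable token.

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([2.0, 1.0, 0.1, -1.0])
    top_p, filter_value = 0.8, -float('Inf')
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()  # shift right so the first token above the threshold is kept
    sorted_indices_to_remove[0] = False
    logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    assert logits[0] == 2.0  # the most probable token survives the filter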
transformers/tests/sampling_test.py
deleted 100644 → 0
# coding=utf-8
import sys
import unittest

import numpy as np
import pytest

from transformers import is_torch_available

if is_torch_available():
    import torch

    from transformers import (
        BertConfig,
        BertModel,
        GPT2Config,
        GPT2LMHeadModel,
        OpenAIGPTConfig,
        OpenAIGPTLMHeadModel,
        TransfoXLConfig,
        TransfoXLLMHeadModel,
        XLMConfig,
        XLMWithLMHeadModel,
        XLNetConfig,
        XLNetLMHeadModel,
        Model2Model,
    )
    from transformers.modeling_utils import Sampler
else:
    pytestmark = pytest.mark.skip("Require Torch")


class SamplerTest(unittest.TestCase):
    def test_nucleus_sampling(self):
        inf = -float("Inf")
        test_cases = (
            {
                "p": 0,
                "logits": torch.tensor([0.3, 0.1, 0.2]),
                "expected": torch.tensor([0.3, 0.1, 0.2]),
            },
            {
                "p": 0.01,
                "logits": torch.tensor([0.3, 0.1, 0.2]),
                "expected": torch.tensor([0.3, inf, inf]),
            },
            {
                "p": 1,
                "logits": torch.tensor([0.3, 0.1, 0.2]),
                "expected": torch.tensor([0.3, 0.1, 0.2]),
            },
            {
                "p": 0.2,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, inf, inf]),
            },
            {
                "p": 0.71,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, inf, 0.2]),
            },
            {
                "p": 0.71,
                "logits": torch.tensor([0.1, 0.7, 0.2]),
                "expected": torch.tensor([inf, 0.7, 0.2]),
            },
            {
                "p": 0.71,
                "logits": torch.tensor([0.7, 0.2, 0.1]),
                "expected": torch.tensor([0.7, 0.2, inf]),
            },
            {
                "p": 0.91,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, 0.1, 0.2]),
            },
        )
        for case in test_cases:
            config = {
                "do_sample": True,
                "temperature": 1.0,
                "k": 0,
                "p": case["p"],
                "repetition_penalty": 1.0,
            }
            sampler = Sampler(**config)
            filtered_logits = sampler.apply_nucleus_filter(case["logits"])
            np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy())

    def test_top_k_filter(self):
        inf = -float("Inf")
        test_cases = (
            {
                "k": 0,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, 0.1, 0.2]),
            },
            {
                "k": 1,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, inf, inf]),
            },
            {
                "k": 2,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, inf, 0.2]),
            },
            {
                "k": 3,
                "logits": torch.tensor([0.7, 0.1, 0.2]),
                "expected": torch.tensor([0.7, 0.1, 0.2]),
            },
        )
        for case in test_cases:
            config = {
                "do_sample": True,
                "temperature": 1.0,
                "k": case["k"],
                "p": 0,
                "repetition_penalty": 1.0,
            }
            sampler = Sampler(**config)
            filtered_logits = sampler.apply_top_k_filter(case["logits"])
            np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy())

    @pytest.mark.skipif(sys.version_info < (3, 2), reason="assertWarns() requires Python >= 3.2")
    def test_wrong_k_value(self):
        case = {"k": 10, "vocab_size": 5}
        config = {
            "do_sample": True,
            "temperature": 1.0,
            "k": case["k"],
            "p": 0,
            "repetition_penalty": 1.0,
        }
        sampler = Sampler(**config)
        next_token_logits = torch.rand(case["vocab_size"]).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertWarns(UserWarning):
            _ = sampler.get_one_token(next_token_logits, past_sequence)

    def test_zero_temperature(self):
        temperature = 0
        config = {
            "do_sample": True,
            "temperature": temperature,
            "k": 0,
            "p": 0,
            "repetition_penalty": 1.0,
        }
        sampler = Sampler(**config)
        next_token_logits = torch.rand(10).unsqueeze(0)
        past_sequence = torch.tensor([])
        with self.assertRaises(ZeroDivisionError):
            _ = sampler.get_one_token(next_token_logits, past_sequence)


class SamplerSingleStackTest(unittest.TestCase):
    def test_raises_exception_when_no_LM_head(self):
        models = [BertModel(BertConfig())]
        for model in models:
            with self.assertRaises(AttributeError):
                model.decode()

    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        models = {
            "XLNet": XLNetLMHeadModel(XLNetConfig()),
            "XLM": XLMWithLMHeadModel(XLMConfig()),
            "TransfoXL": TransfoXLLMHeadModel(TransfoXLConfig()),
            "GPT2": GPT2LMHeadModel(GPT2Config()),
            "GPT": OpenAIGPTLMHeadModel(OpenAIGPTConfig()),
        }
        kwargs = {
            "XLNet": {},
            "XLM": {"mask_token": 0},
            "TransfoXL": {},
            "GPT2": {},
            "GPT": {},
        }
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8
        for name, model in models.items():
            kwargs_model = kwargs[name]
            output = model.decode(prompt_ids=prompt, length=generated_length, **kwargs_model)
            self.assertEqual(len(output), expected_length)


class SamplerEncoderDecoderTest(unittest.TestCase):
    @pytest.mark.slow
    def test_forward_pass_and_output_length(self):
        model = Model2Model.from_pretrained("bert-base-uncased")
        encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
        prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
        generated_length = 5
        expected_length = 8
        output = model.decode(
            encoder_input_ids,
            decoder_prompt_ids=prompt,
            k=2,
            p=0.5,
            repetition_penalty=2,
            length=generated_length,
        )
        self.assertEqual(len(output), expected_length)


if __name__ == "__main__":
    unittest.main()