Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a468870f
Commit
a468870f
authored
Dec 16, 2019
by
thomwolf
Browse files
refactoring generation
parent
07bc8efb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
213 additions
and
227 deletions
+213
-227
transformers/configuration_utils.py
transformers/configuration_utils.py
+11
-0
transformers/modeling_utils.py
transformers/modeling_utils.py
+202
-227
No files found.
transformers/configuration_utils.py
View file @
a468870f
...
...
@@ -57,8 +57,19 @@ class PretrainedConfig(object):
self
.
torchscript
=
kwargs
.
pop
(
'torchscript'
,
False
)
# Only used by PyTorch models
self
.
use_bfloat16
=
kwargs
.
pop
(
'use_bfloat16'
,
False
)
self
.
pruned_heads
=
kwargs
.
pop
(
'pruned_heads'
,
{})
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
self
.
is_decoder
=
kwargs
.
pop
(
'is_decoder'
,
False
)
# Parameters for sequence generation
self
.
generate_length
=
kwargs
.
pop
(
'generate_length'
,
10
)
self
.
generate_do_sample
=
kwargs
.
pop
(
'generate_do_sample'
,
False
)
self
.
generate_num_beams
=
kwargs
.
pop
(
'generate_num_beams'
,
1
)
self
.
generate_temperature
=
kwargs
.
pop
(
'generate_temperature'
,
1.0
)
self
.
generate_top_k
=
kwargs
.
pop
(
'generate_top_k'
,
50
)
self
.
generate_top_p
=
kwargs
.
pop
(
'generate_top_p'
,
0.0
)
self
.
generate_repetition_penalty
=
kwargs
.
pop
(
'generate_repetition_penalty'
,
1.0
)
def
save_pretrained
(
self
,
save_directory
):
""" Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
...
...
transformers/modeling_utils.py
View file @
a468870f
...
...
@@ -82,6 +82,7 @@ class PreTrainedModel(nn.Module):
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
))
# Save config in model
self
.
config
=
config
...
...
@@ -89,93 +90,6 @@ class PreTrainedModel(nn.Module):
def base_model(self):
    """Return the underlying bare transformer model.

    Derived models (e.g. models with a task head on top) expose the bare
    model under the attribute named by ``base_model_prefix``; bare models
    have no such attribute and are returned unchanged.
    """
    prefix = self.base_model_prefix
    return getattr(self, prefix, self)
def decode(self, prompt_ids=None, device=torch.device('cpu'), length=10, do_sample=False,
           temperature=1., k=9, p=0, repetition_penalty=1, **model_kwargs):
    """ Generic sequence generator for single-stack models with a LM head.

    The method currently supports greedy decoding and sampling. See the
    documentation of the `Sampler` class for more information about the
    parameters related to sampling.

    Params:
        **prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
            The sequence used as a prompt for the generation. If `None` the method
            initializes it as an empty `torch.LongTensor` of shape (1, 0).
        **device**: (`optional`) `torch.device`
            The device on which the prompt_ids will be initialized if not provided.
        **length**: (`optional`) int
            The length of the sequence to be generated.
        **do_sample**: (`optional`) bool
            If set to `False` we use greedy decoding; otherwise sampling.
        **temperature**: (`optional`) float
            The value used to module the next token probabilities.
        **k**: (`optional`) int
            The parameter used for k-filtering.
        **p**: (`optional`) float
            The parameter for nucleus sampling. Must be between 0 and 1.
        **repetition_penalty**: (`optional`) float
            The parameter for repetition penalty.

    Raises:
        AttributeError: if the model has no LM head.
    """
    if prompt_ids is None:
        prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)

    # When the model does not have a LM head `get_output_embeddings`
    # returns `None`. We use this mechanism to determine whether we
    # should proceed with decoding or not.
    if self.get_output_embeddings() is None:
        raise AttributeError("You tried to generate sequences with a model that does not have a LM Head.")

    # The following checks that the model is on the same device as the one
    # that is specified. It only works for models that fit on one GPU.
    model_device = next(self.parameters()).device
    if model_device != prompt_ids.device:
        # The prompts must live where the model's parameters live, so the
        # model device is the "expected" one (the original message had the
        # two format arguments swapped).
        warnings.warn(
            "The model is not on the same device as the prompts. Expected {}, got {}.".format(
                model_device, prompt_ids.device
            )
        )

    sampler_config = {
        "k": k,
        "p": p,
        "do_sample": do_sample,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
    }
    return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs)
def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs):
    """Generate `length` tokens autoregressively, greedily or by sampling.

    A `Sampler` built from `sampler_config` picks one token per step from
    the logits of the last position; each token is appended to the running
    sequence. Returns the sequence with the batch dimension squeezed out.
    """
    sampler = Sampler(**sampler_config)
    sequence = prompt_ids
    with torch.no_grad():
        for _ in trange(length):
            model_inputs = self._prepare_inputs_for_decoding(sequence, **model_kwargs)
            outputs = self(**model_inputs)
            last_position_logits = outputs[0][:, -1, :]
            chosen_token = sampler.get_one_token(last_position_logits, sequence)
            sequence = torch.cat((sequence, chosen_token), dim=1)
    return sequence.squeeze(0)
def
_prepare_inputs_for_decoding
(
self
,
input_ids
,
**
kwargs
):
arguments
=
{
"input_ids"
:
input_ids
}
arguments
.
update
(
kwargs
)
return
arguments
def
get_input_embeddings
(
self
):
""" Get model's input embeddings
"""
...
...
@@ -306,6 +220,9 @@ class PreTrainedModel(nn.Module):
# Tie weights if needed
self
.
tie_weights
()
# Initialize decoding head if we have output embeddings
def
prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the base model.
...
...
@@ -571,6 +488,204 @@ class PreTrainedModel(nn.Module):
return
model
def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None,
             temperature=None, top_k=None, top_p=None, repetition_penalty=None,
             **model_kwargs):
    """ Generic sequence generator for single-stack models with a LM head.

    The method currently supports greedy decoding and sampling. See the
    documentation of the `Sampler` class for more information about the
    parameters related to sampling. Arguments left as `None` fall back on
    the `generate_*` defaults stored in the model configuration.

    Params:
        **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
            The sequence used as a prompt for the generation. If `None` the method
            initializes it as an empty `torch.LongTensor` of shape (1, 0).
        **length**: (`optional`) int
            The length of the sequence to be generated.
        **do_sample**: (`optional`) bool
            If set to `False` we use greedy decoding; otherwise sampling.
        **num_beams**: (`optional`) int
            Accepted for forward compatibility; beam search is not implemented yet.
        **temperature**: (`optional`) float
            The value used to module the next token probabilities.
        **top_k**: (`optional`) int
            The parameter used for k-filtering.
        **top_p**: (`optional`) float
            The parameter for nucleus sampling. Must be between 0 and 1.
        **repetition_penalty**: (`optional`) float
            The parameter for repetition penalty.

    Raises:
        AttributeError: if the model has no LM head.
    """
    # Resolve every unset argument against the generation defaults that
    # `PretrainedConfig` stores (`generate_length`, `generate_temperature`, ...);
    # the original left them as `None` and would have crashed in `trange(None)`.
    length = length if length is not None else self.config.generate_length
    num_beams = num_beams if num_beams is not None else self.config.generate_num_beams
    temperature = temperature if temperature is not None else self.config.generate_temperature
    top_k = top_k if top_k is not None else self.config.generate_top_k
    top_p = top_p if top_p is not None else self.config.generate_top_p
    repetition_penalty = (repetition_penalty if repetition_penalty is not None
                          else self.config.generate_repetition_penalty)

    if input_ids is None:
        input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device)

    # We cannot generate if the model does not have a LM head
    if self.get_output_embeddings() is None:
        raise AttributeError("You tried to generate sequences with a model that does not have a LM Head.")

    sampler_config = {
        "k": top_k,  # the original referenced an undefined name `k` -> NameError
        "p": top_p,  # the original referenced an undefined name `p` -> NameError
        "do_sample": do_sample,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
    }
    sampler = Sampler(**sampler_config)

    generated_sequence = input_ids
    # No gradients are needed at generation time; `decode` already wraps its
    # loop in `no_grad`, this makes the two paths consistent.
    with torch.no_grad():
        for _ in trange(length):
            arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
            outputs = self(**arguments)
            next_tokens_logits = outputs[0][:, -1, :]
            next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence)
            generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)

    return generated_sequence.squeeze(0)
def
_prepare_inputs_for_decoding
(
self
,
input_ids
,
**
model_kwargs
):
return
model_kwargs
.
update
({
"input_ids"
:
input_ids
})
class Sampler(object):
    r""" Sampler is used to generate sequences of ids from logit inputs.

    Greedy decoding, which consists in chosing the most probable token at each
    step, is the default behaviour. Sampling with varying temperature, top_k
    and nucleus filtering is also implemented.

    Attributes:
        **do_sample**: bool
            Whether to sample or do greedy decoding.
        **k**: int between 0 and vocab_size
            Parameter for the top-k filtering
        **p**: float between 0 and 1
            Parameter for the nucleus filtering
        **temperature**: strictly positive float
            Parameter used to modulate the distribution over ids. Low temperatures
            put more emphasis on highly probably token while high temperatures tend
            to smooth the probability distribution.
        **repetition_penalty**: strictly postitive float
            The penalty applied to repeating ids
    """

    def __init__(self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0):
        self.k = k
        self.p = p
        self.do_sample = do_sample
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty
        # Only iterate over past tokens when a penalty was actually requested.
        self.do_apply_repetition_penalty = repetition_penalty > 1
        if self.p > 1:
            warnings.warn(
                """You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
However p is a probability and its value must lie between 0 and 1. In effect, no filtering
will be applied. If this is not the behavior you expect, change the value of p.""".format(self.p)
            )

    def get_one_token(self, next_token_logits, past_sequence):
        """Return the next token id as a (1, 1) tensor, greedy or sampled."""
        scores = self.apply_repetition_penalty(next_token_logits, past_sequence)
        if not self.do_sample:
            # Greedy decoding: the single most probable token.
            return torch.argmax(scores, dim=-1).unsqueeze(-1)
        scores = self.apply_temperature(scores)
        scores = self.apply_top_k_filter(scores)
        scores = self.apply_nucleus_filter(scores)
        return torch.multinomial(F.softmax(scores, dim=-1), num_samples=1)

    def apply_repetition_penalty(self, logits, past_sequence):
        """ Apply a penalty to tokens that appear more than once in the
        generated sequence.

        .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
        language model for controllable generation." arXiv preprint
        arXiv:1909.05858 (2019).
        """
        if self.do_apply_repetition_penalty:
            for token_idx in set(past_sequence[0].tolist()):
                logits[0, token_idx] /= self.repetition_penalty
        return logits

    def apply_temperature(self, logits):
        """ Shape the tokens' distribution through temperature. The higher the value
        of the temperature, the more skewed towards high probability events the
        distribution is.

        .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
        MIT press, 2016.
        """
        # when dividing a float by 0, torch returns inf which in turns breaks the
        # multinomial with an error message that is not very helpful. It is better
        # for the user to break the execution and explain why.
        if self.temperature == 0:
            raise ZeroDivisionError(
                """You are trying to sample with a temperature equal to 0.
If you wanted to do greedy sampling, set instead `do_sample` to False.
Otherwise set the temperature to a value different from 0."""
            )
        return logits / self.temperature

    def apply_top_k_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically we select the set of size k such that
        the sum of its items' probabilities is maximum.

        .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
        story generation." arXiv preprint arXiv:1805.04833 (2018).
        """
        if self.k <= 0:
            return logits
        vocabulary_size = logits.size(-1)
        if self.k > vocabulary_size:
            warnings.warn(
                """You provided a value for k ({}) that is larger than the vocabulary size ({}).
We adjusted k's value to the vocabulary size; if that was what you intended to do
we recommend setting k to 0 instead. It this is not the behavior you expected,
choose a value of k that is smaller than the vocabulary size.""".format(self.k, vocabulary_size)
            )
            self.k = vocabulary_size
        # Everything strictly below the k-th best logit is masked out.
        kth_best = torch.topk(logits, self.k)[0][..., -1, None]
        logits[logits < kth_best] = -float("Inf")
        return logits

    def apply_nucleus_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically, choose the smallest set such that the
        sum of its items' probabilities is greater than a number p in [0,1].

        .. Holtzman, Ari, et al. "The curious case of neural text
        degeneration." arXiv preprint arXiv:1904.09751 (2019).
        """
        if self.p <= 0:
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold,
        # but keep the first token above the threshold.
        sorted_mask = cumulative_probabilities > self.p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = 0
        # scatter sorted tensors to original indexing
        mask = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
        logits[mask] = -float("Inf")
        return logits
class
Conv1D
(
nn
.
Module
):
def
__init__
(
self
,
nf
,
nx
):
...
...
@@ -948,143 +1063,3 @@ def prune_layer(layer, index, dim=None):
return
prune_conv1d_layer
(
layer
,
index
,
dim
=
1
if
dim
is
None
else
dim
)
else
:
raise
ValueError
(
"Can't prune layer of class {}"
.
format
(
layer
.
__class__
))
class Sampler(object):
    r""" Sampler is used to generate sequences of ids from logit inputs.

    Greedy decoding, which consists in chosing the most probable token at each
    step, is the default behaviour. Sampling with varying temperature, top_k
    and nucleus filtering is also implemented.

    Attributes:
        **do_sample**: bool
            Whether to sample or do greedy decoding.
        **k**: int between 0 and vocab_size
            Parameter for the top-k filtering
        **p**: float between 0 and 1
            Parameter for the nucleus filtering
        **temperature**: strictly positive float
            Parameter used to modulate the distribution over ids. Low temperatures
            put more emphasis on highly probably token while high temperatures tend
            to smooth the probability distribution.
        **repetition_penalty**: strictly postitive float
            The penalty applied to repeating ids
    """

    def __init__(self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0):
        self.k = k
        self.p = p
        self.do_sample = do_sample
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty
        # Only iterate over past tokens when a penalty was actually requested.
        self.do_apply_repetition_penalty = repetition_penalty > 1
        if self.p > 1:
            warnings.warn(
                """You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
However p is a probability and its value must lie between 0 and 1. In effect, no filtering
will be applied. If this is not the behavior you expect, change the value of p.""".format(self.p)
            )

    def get_one_token(self, next_token_logits, past_sequence):
        """Return the next token id as a (1, 1) tensor, greedy or sampled."""
        scores = self.apply_repetition_penalty(next_token_logits, past_sequence)
        if not self.do_sample:
            # Greedy decoding: the single most probable token.
            return torch.argmax(scores, dim=-1).unsqueeze(-1)
        scores = self.apply_temperature(scores)
        scores = self.apply_top_k_filter(scores)
        scores = self.apply_nucleus_filter(scores)
        return torch.multinomial(F.softmax(scores, dim=-1), num_samples=1)

    def apply_repetition_penalty(self, logits, past_sequence):
        """ Apply a penalty to tokens that appear more than once in the
        generated sequence.

        .. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
        language model for controllable generation." arXiv preprint
        arXiv:1909.05858 (2019).
        """
        if self.do_apply_repetition_penalty:
            for token_idx in set(past_sequence[0].tolist()):
                logits[0, token_idx] /= self.repetition_penalty
        return logits

    def apply_temperature(self, logits):
        """ Shape the tokens' distribution through temperature. The higher the value
        of the temperature, the more skewed towards high probability events the
        distribution is.

        .. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
        MIT press, 2016.
        """
        # when dividing a float by 0, torch returns inf which in turns breaks the
        # multinomial with an error message that is not very helpful. It is better
        # for the user to break the execution and explain why.
        if self.temperature == 0:
            raise ZeroDivisionError(
                """You are trying to sample with a temperature equal to 0.
If you wanted to do greedy sampling, set instead `do_sample` to False.
Otherwise set the temperature to a value different from 0."""
            )
        return logits / self.temperature

    def apply_top_k_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically we select the set of size k such that
        the sum of its items' probabilities is maximum.

        .. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
        story generation." arXiv preprint arXiv:1805.04833 (2018).
        """
        if self.k <= 0:
            return logits
        vocabulary_size = logits.size(-1)
        if self.k > vocabulary_size:
            warnings.warn(
                """You provided a value for k ({}) that is larger than the vocabulary size ({}).
We adjusted k's value to the vocabulary size; if that was what you intended to do
we recommend setting k to 0 instead. It this is not the behavior you expected,
choose a value of k that is smaller than the vocabulary size.""".format(self.k, vocabulary_size)
            )
            self.k = vocabulary_size
        # Everything strictly below the k-th best logit is masked out.
        kth_best = torch.topk(logits, self.k)[0][..., -1, None]
        logits[logits < kth_best] = -float("Inf")
        return logits

    def apply_nucleus_filter(self, logits):
        """ Use the probability distribution of the tokens to determine the set
        to be sampled from. Specifically, choose the smallest set such that the
        sum of its items' probabilities is greater than a number p in [0,1].

        .. Holtzman, Ari, et al. "The curious case of neural text
        degeneration." arXiv preprint arXiv:1904.09751 (2019).
        """
        if self.p <= 0:
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold,
        # but keep the first token above the threshold.
        sorted_mask = cumulative_probabilities > self.p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = 0
        # scatter sorted tensors to original indexing
        mask = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
        logits[mask] = -float("Inf")
        return logits
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment