chenpangpang / transformers · Commits

Commit 95e00d08 (unverified)
Authored Mar 20, 2020 by Patrick von Platen; committed by GitHub on Mar 20, 2020
Parent: 8becb732

Clean special token init in modeling_....py (#3264)

* make style
* fix conflicts
Showing 20 changed files with 115 additions and 113 deletions
examples/summarization/bart/evaluate_cnn.py      +1  -1
src/transformers/__init__.py                     +1  -0
src/transformers/configuration_albert.py         +4  -1
src/transformers/configuration_bart.py           +4  -4
src/transformers/configuration_bert.py           +2  -1
src/transformers/configuration_distilbert.py     +2  -1
src/transformers/configuration_flaubert.py       +2  -2
src/transformers/configuration_gpt2.py           +2  -2
src/transformers/configuration_roberta.py        +5  -0
src/transformers/configuration_t5.py             +2  -2
src/transformers/configuration_transfo_xl.py     +1  -3
src/transformers/configuration_utils.py          +1  -1
src/transformers/configuration_xlm.py            +2  -5
src/transformers/configuration_xlnet.py          +3  -3
src/transformers/modeling_bart.py                +3  -3
src/transformers/modeling_tf_utils.py            +39 -41
src/transformers/modeling_utils.py               +34 -36
tests/test_modeling_bart.py                      +5  -5
tests/test_modeling_gpt2.py                      +1  -1
tests/test_modeling_tf_gpt2.py                   +1  -1
examples/summarization/bart/evaluate_cnn.py

@@ -35,7 +35,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
             min_length=min_length + 1,  # +1 from original because we start at step=1
             no_repeat_ngram_size=3,
             early_stopping=True,
-            decoder_start_token_id=model.config.eos_token_ids[0],
+            decoder_start_token_id=model.config.eos_token_id,
         )
         dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
         for hypothesis in dec:
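In practice, the rename means callers read a single integer off the config instead of indexing into a list. A minimal sketch of the pattern used by the CNN/DailyMail evaluation script; the checkpoint name "bart-large-cnn" and the generation settings are assumptions for illustration, not part of this diff:

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn")

article = "..."  # placeholder: any article text to summarize
inputs = tokenizer.batch_encode_plus([article], max_length=1024, return_tensors="pt")

summary_ids = model.generate(
    inputs["input_ids"],
    num_beams=4,
    max_length=142,
    min_length=56,
    no_repeat_ngram_size=3,
    early_stopping=True,
    # after this commit the config exposes a single id, not a list
    decoder_start_token_id=model.config.eos_token_id,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))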
src/transformers/__init__.py

@@ -223,6 +223,7 @@ if is_torch_available():
         BartForSequenceClassification,
         BartModel,
         BartForConditionalGeneration,
+        BART_PRETRAINED_MODEL_ARCHIVE_MAP,
     )
     from .modeling_roberta import (
         RobertaForMaskedLM,
src/transformers/configuration_albert.py

@@ -124,9 +124,12 @@ class AlbertConfig(PretrainedConfig):
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         classifier_dropout_prob=0.1,
+        pad_token_id=0,
+        bos_token_id=2,
+        eos_token_id=3,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.embedding_size = embedding_size
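The same pattern repeats across the configuration files below: the special token ids become explicit constructor arguments and are forwarded to PretrainedConfig.__init__ instead of being left in **kwargs. A minimal sketch with a hypothetical config class; ToyConfig, its model_type, and its default ids are illustrative, not part of the library:

from transformers import PretrainedConfig

class ToyConfig(PretrainedConfig):
    """Hypothetical config, only to illustrate the new initialization pattern."""

    model_type = "toy"

    def __init__(self, vocab_size=1000, pad_token_id=0, bos_token_id=2, eos_token_id=3, **kwargs):
        # forward the special token ids so the base class stores them as attributes
        super().__init__(
            pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
        )
        self.vocab_size = vocab_size

config = ToyConfig()
print(config.pad_token_id, config.bos_token_id, config.eos_token_id)  # 0 2 3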
src/transformers/configuration_bart.py

@@ -41,9 +41,6 @@ class BartConfig(PretrainedConfig):
         activation_dropout=0.0,
         activation_function="gelu",
         vocab_size=50265,
-        bos_token_id=0,
-        pad_token_id=1,
-        eos_token_ids=[2],
         d_model=1024,
         encoder_ffn_dim=4096,
         encoder_layers=12,
@@ -61,6 +58,9 @@ class BartConfig(PretrainedConfig):
         output_past=False,
         num_labels=3,
         is_encoder_decoder=True,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
         **common_kwargs
     ):
         r"""
@@ -74,7 +74,7 @@ class BartConfig(PretrainedConfig):
             output_past=output_past,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
-            eos_token_ids=eos_token_ids,
+            eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             **common_kwargs,
         )
src/transformers/configuration_bert.py

@@ -124,9 +124,10 @@ class BertConfig(PretrainedConfig):
         type_vocab_size=2,
         initializer_range=0.02,
         layer_norm_eps=1e-12,
+        pad_token_id=0,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
src/transformers/configuration_distilbert.py

@@ -113,9 +113,10 @@ class DistilBertConfig(PretrainedConfig):
         initializer_range=0.02,
         qa_dropout=0.1,
         seq_classif_dropout=0.2,
+        pad_token_id=0,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(**kwargs, pad_token_id=pad_token_id)
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.sinusoidal_pos_embds = sinusoidal_pos_embds
src/transformers/configuration_flaubert.py

@@ -145,9 +145,9 @@ class FlaubertConfig(XLMConfig):
     pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "flaubert"

-    def __init__(self, layerdrop=0.0, pre_norm=False, **kwargs):
+    def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs):
         """Constructs FlaubertConfig.
         """
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
         self.layerdrop = layerdrop
         self.pre_norm = pre_norm
src/transformers/configuration_gpt2.py

@@ -142,7 +142,7 @@ class GPT2Config(PretrainedConfig):
         eos_token_id=50256,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.n_ctx = n_ctx
@@ -163,7 +163,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_proj_to_labels = summary_proj_to_labels
         self.bos_token_id = bos_token_id
-        self.eos_token_ids = [eos_token_id]
+        self.eos_token_id = eos_token_id

     @property
     def max_position_embeddings(self):
src/transformers/configuration_roberta.py

@@ -66,3 +66,8 @@ class RobertaConfig(BertConfig):
     """

     pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "roberta"
+
+    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
+        """Constructs FlaubertConfig.
+        """
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
src/transformers/configuration_t5.py

@@ -77,11 +77,11 @@ class T5Config(PretrainedConfig):
         initializer_factor=1.0,
         is_encoder_decoder=True,
         pad_token_id=0,
-        eos_token_ids=[1],
+        eos_token_id=1,
         **kwargs
     ):
         super().__init__(
-            is_encoder_decoder=is_encoder_decoder, **kwargs,
+            pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, **kwargs,
         )
         self.vocab_size = vocab_size
         self.n_positions = n_positions
src/transformers/configuration_transfo_xl.py

@@ -152,7 +152,7 @@ class TransfoXLConfig(PretrainedConfig):
         eos_token_id=0,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(eos_token_id=eos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.cutoffs = []
@@ -187,8 +187,6 @@ class TransfoXLConfig(PretrainedConfig):
         self.init_std = init_std
         self.layer_norm_epsilon = layer_norm_epsilon
-        self.eos_token_ids = [eos_token_id]

     @property
     def max_position_embeddings(self):
         return self.tgt_len + self.ext_len + self.mem_len
src/transformers/configuration_utils.py

@@ -80,7 +80,7 @@ class PretrainedConfig(object):
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
         self.bos_token_id = kwargs.pop("bos_token_id", None)
         self.pad_token_id = kwargs.pop("pad_token_id", None)
-        self.eos_token_ids = kwargs.pop("eos_token_ids", None)
+        self.eos_token_id = kwargs.pop("eos_token_id", None)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
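Because the base class pops these values straight out of **kwargs, any config accepts the renamed keyword even when its own signature does not list it explicitly. A small sketch of that behavior; the chosen ids are arbitrary illustration values:

from transformers import BertConfig

# eos_token_id travels through **kwargs to PretrainedConfig.__init__,
# which stores it via kwargs.pop("eos_token_id", None) as shown above
config = BertConfig(pad_token_id=0, eos_token_id=102)
print(config.pad_token_id)  # 0
print(config.eos_token_id)  # 102
print(config.bos_token_id)  # None, since it was never provided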
src/transformers/configuration_xlm.py

@@ -194,13 +194,13 @@ class XLMConfig(PretrainedConfig):
         end_n_top=5,
         mask_token_id=0,
         lang_id=0,
-        bos_token_id=0,
         pad_token_id=2,
+        bos_token_id=0,
         **kwargs
     ):
         """Constructs XLMConfig.
         """
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.emb_dim = emb_dim
         self.n_layers = n_layers
@@ -236,9 +236,6 @@ class XLMConfig(PretrainedConfig):
         if "n_words" in kwargs:
             self.n_words = kwargs["n_words"]
-        self.bos_token_id = bos_token_id
-        self.pad_token_id = pad_token_id

     @property
     def n_words(self):  # For backward compatibility
         return self.vocab_size
src/transformers/configuration_xlnet.py

@@ -155,14 +155,14 @@ class XLNetConfig(PretrainedConfig):
         summary_last_dropout=0.1,
         start_n_top=5,
         end_n_top=5,
-        bos_token_id=1,
         pad_token_id=5,
+        bos_token_id=1,
         eos_token_id=2,
         **kwargs
     ):
         """Constructs XLNetConfig.
         """
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.d_model = d_model
         self.n_layer = n_layer
@@ -193,7 +193,7 @@ class XLNetConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.pad_token_id = pad_token_id
-        self.eos_token_ids = [eos_token_id]
+        self.eos_token_id = eos_token_id

     @property
     def max_position_embeddings(self):
src/transformers/modeling_bart.py

@@ -906,8 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def prepare_scores_for_generation(self, scores, cur_len, max_length):
         if cur_len == 1:
             self._force_token_ids_generation(scores, self.config.bos_token_id)
-        if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
-            self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
+        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
+            self._force_token_ids_generation(scores, self.config.eos_token_id)
         return scores

     @staticmethod
@@ -1003,7 +1003,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
         )
         x = outputs[0]  # last hidden state
-        eos_mask = input_ids.eq(self.config.eos_token_ids[0])
+        eos_mask = input_ids.eq(self.config.eos_token_id)
         if len(torch.unique(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
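The classification head pools the decoder output at the final <eos> position, now located with the single eos_token_id. A standalone sketch of that masking trick with made-up tensors; shapes and ids are illustrative only:

import torch

eos_token_id = 2
hidden_states = torch.randn(2, 7, 16)             # (batch, seq_len, hidden)
input_ids = torch.tensor([[5, 8, 9, 2, 1, 1, 1],
                          [5, 8, 9, 4, 6, 7, 2]])  # exactly one <eos> per row

eos_mask = input_ids.eq(eos_token_id)
# gather the hidden state at the last <eos> of every sequence
sentence_representation = hidden_states[eos_mask, :].view(
    hidden_states.size(0), -1, hidden_states.size(-1)
)[:, -1, :]
print(sentence_representation.shape)               # torch.Size([2, 16])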
src/transformers/modeling_tf_utils.py

@@ -469,7 +469,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         repetition_penalty=None,
         bos_token_id=None,
         pad_token_id=None,
-        eos_token_ids=None,
+        eos_token_id=None,
         length_penalty=None,
         no_repeat_ngram_size=None,
         num_return_sequences=None,
@@ -518,13 +518,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
             bos_token_id: (`optional`) int
-                Beginning of sentence token if no prompt is provided. Default to 0.
+                Beginning of sentence token if no prompt is provided. Default to specicic model bos_token_id or None if it does not exist.
             pad_token_id: (`optional`) int
                 Pad token. Defaults to pad_token_id as defined in the models config.
             eos_token_ids: (`optional`) int or list of int
                 End of sequence token or list of tokens to stop the generation. Default to 0.
             length_penalty: (`optional`) float
                 Exponential penalty to the length. Default to 1.
@@ -601,7 +602,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-       eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
+       eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
@@ -615,8 +616,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            batch_size = shape_list(input_ids)[0]  # overriden by the input batch_size
        else:
            batch_size = 1
-       if isinstance(eos_token_ids, int):
-           eos_token_ids = [eos_token_ids]
        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
@@ -633,9 +632,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
-       assert (eos_token_ids is None) or (
-           isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
-       ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
+       assert (eos_token_id is None) or (
+           isinstance(eos_token_id, int) and (eos_token_id >= 0)
+       ), "`eos_token_id` should be a positive integer."
        assert (
            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
@@ -674,11 +673,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        elif attention_mask is None:
            attention_mask = tf.ones_like(input_ids)

-       if pad_token_id is None and eos_token_ids is not None:
+       if pad_token_id is None and eos_token_id is not None:
            logger.warning(
-               "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
+               "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
-           pad_token_id = eos_token_ids[0]
+           pad_token_id = eos_token_id

        # current position and vocab size
        cur_len = shape_list(input_ids)[1]
@@ -742,7 +741,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                no_repeat_ngram_size=no_repeat_ngram_size,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
-               eos_token_ids=eos_token_ids,
+               eos_token_id=eos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
@@ -766,7 +765,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                no_repeat_ngram_size=no_repeat_ngram_size,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
-               eos_token_ids=eos_token_ids,
+               eos_token_id=eos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                batch_size=effective_batch_size,
                vocab_size=vocab_size,
@@ -790,7 +789,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        no_repeat_ngram_size,
        bos_token_id,
        pad_token_id,
-       eos_token_ids,
+       eos_token_id,
        decoder_start_token_id,
        batch_size,
        vocab_size,
@@ -839,10 +838,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                )
            # set eos token prob to zero if min_length is not reached
-           if eos_token_ids is not None and cur_len < min_length:
-               # create eos_token_ids boolean mask
+           if eos_token_id is not None and cur_len < min_length:
+               # create eos_token_id boolean mask
                is_token_logit_eos_token = tf.convert_to_tensor(
-                   [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+                   [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
                )
                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
@@ -865,28 +864,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)

            # update generations and finished sentences
-           if eos_token_ids is not None:
-               # pad finished sentences if eos_token_ids exist
+           if eos_token_id is not None:
+               # pad finished sentences if eos_token_id exist
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)

-           if eos_token_ids is not None:
-               for eos_token_id in eos_token_ids:
+           if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
                    unfinished_sents, tf.cast(eos_in_sents, tf.int32)
                )
                sent_lengths = (
                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
                )
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos

            # stop when there is a </s> in each sentence, or if we exceed the maximul length
            if tf.math.reduce_max(unfinished_sents) == 0:
@@ -937,8 +935,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        no_repeat_ngram_size,
        bos_token_id,
        pad_token_id,
-       eos_token_ids,
        decoder_start_token_id,
+       eos_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
@@ -996,10 +994,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)

            # set eos token prob to zero if min_length is not reached
-           if eos_token_ids is not None and cur_len < min_length:
-               # create eos_token_ids boolean mask
+           if eos_token_id is not None and cur_len < min_length:
+               # create eos_token_id boolean mask
                is_token_logit_eos_token = tf.convert_to_tensor(
-                   [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+                   [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
                )
                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
@@ -1072,7 +1070,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
-                       eos_token_ids is not None and pad_token_id is not None
+                       eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue
@@ -1091,7 +1089,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence or last iteration
-                   if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
+                   if eos_token_id is not None and token_id.numpy() is eos_token_id:
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
@@ -1148,8 +1146,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                if done[batch_idx]:
                    continue
                # test that beam scores match previously calculated scores if not eos and batch_idx not done
-               if eos_token_ids is not None and all(
-                   (token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx]
+               if eos_token_id is not None and all(
+                   (token_id % vocab_size).numpy().item() is not eos_token_id for token_id in next_tokens[batch_idx]
                ):
                    assert tf.reduce_all(
                        next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
@@ -1199,7 +1197,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                if sent_lengths[i] < max_length:
                    decoded_hypo = tf.where(
                        tf.range(max_length) == sent_lengths[i],
-                       eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
+                       eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
                        decoded_hypo,
                    )
                decoded_list.append(decoded_hypo)
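The TF branch keeps the end-of-sequence token from being chosen before min_length is reached by building a boolean mask over the vocabulary and pushing the masked logits to a very low value. A self-contained sketch of that idea outside the library; sizes and ids are illustrative, and the membership test uses == rather than the library's identity comparison:

import tensorflow as tf

vocab_size, batch_size, eos_token_id = 10, 3, 2

# True only at the EOS index, then broadcast across the batch
is_eos = tf.convert_to_tensor([token == eos_token_id for token in range(vocab_size)], dtype=tf.bool)
eos_mask = tf.broadcast_to(is_eos, [batch_size, vocab_size])

logits = tf.random.normal([batch_size, vocab_size])
# drive the EOS logit down so it cannot be selected yet
masked_logits = tf.where(eos_mask, tf.fill([batch_size, vocab_size], -1e9), logits)
print(masked_logits[:, eos_token_id])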
src/transformers/modeling_utils.py

@@ -665,7 +665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        repetition_penalty=None,
        bos_token_id=None,
        pad_token_id=None,
-       eos_token_ids=None,
+       eos_token_id=None,
        length_penalty=None,
        no_repeat_ngram_size=None,
        num_return_sequences=None,
@@ -713,6 +713,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
+           pad_token_id: (`optional`) int
+               Padding token. Default to specicic model pad_token_id or None if it does not exist.
            bos_token_id: (`optional`) int
                BOS token. Defaults to bos_token_id as defined in the models config.
@@ -798,7 +801,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-       eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
+       eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
@@ -812,8 +815,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            batch_size = input_ids.shape[0]  # overriden by the input batch_size
        else:
            batch_size = 1
-       if isinstance(eos_token_ids, int):
-           eos_token_ids = [eos_token_ids]
        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
@@ -830,12 +831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
-       assert (eos_token_ids is None) or (
-           isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
-       ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
        assert (
            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
+       assert (eos_token_id is None) or (
+           isinstance(eos_token_id, int) and (eos_token_id >= 0)
+       ), "`eos_token_id` should be a positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
@@ -876,13 +877,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        elif attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

-       # set pad_token_id to eos_token_ids if not set. Important that this is done after
+       # set pad_token_id to eos_token_id if not set. Important that this is done after
        # attention_mask is created
-       if pad_token_id is None and eos_token_ids is not None:
+       if pad_token_id is None and eos_token_id is not None:
            logger.warning(
-               "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
+               "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
-           pad_token_id = eos_token_ids[0]
+           pad_token_id = eos_token_id

        # current position and vocab size
        vocab_size = self.config.vocab_size
@@ -947,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                no_repeat_ngram_size=no_repeat_ngram_size,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
-               eos_token_ids=eos_token_ids,
                decoder_start_token_id=decoder_start_token_id,
+               eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
@@ -971,8 +972,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                no_repeat_ngram_size=no_repeat_ngram_size,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
-               eos_token_ids=eos_token_ids,
                decoder_start_token_id=decoder_start_token_id,
+               eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
@@ -994,7 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        no_repeat_ngram_size,
        bos_token_id,
        pad_token_id,
-       eos_token_ids,
+       eos_token_id,
        decoder_start_token_id,
        batch_size,
        encoder_outputs,
@@ -1031,9 +1032,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

            # set eos token prob to zero if min_length is not reached
-           if eos_token_ids is not None and cur_len < min_length:
-               for eos_token_id in eos_token_ids:
+           if eos_token_id is not None and cur_len < min_length:
                next_token_logits[:, eos_token_id] = -float("inf")

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
@@ -1049,22 +1049,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
-           if eos_token_ids is not None:
-               # pad finished sentences if eos_token_ids exist
+           if eos_token_id is not None:
+               # pad finished sentences if eos_token_id exist
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

-           if eos_token_ids is not None:
-               for eos_token_id in eos_token_ids:
+           if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is a </s> in each sentence, or if we exceed the maximul length
            if unfinished_sents.max() == 0:
@@ -1106,7 +1105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        no_repeat_ngram_size,
        bos_token_id,
        pad_token_id,
-       eos_token_ids,
+       eos_token_id,
        decoder_start_token_id,
        batch_size,
        num_return_sequences,
@@ -1163,9 +1162,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)

            # set eos token prob to zero if min_length is not reached
-           if eos_token_ids is not None and cur_len < min_length:
-               for eos_token_id in eos_token_ids:
+           if eos_token_id is not None and cur_len < min_length:
                scores[:, eos_token_id] = -float("inf")

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
@@ -1225,7 +1223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
-                       eos_token_ids is not None and pad_token_id is not None
+                       eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue
@@ -1244,7 +1242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence
-                   if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
+                   if (eos_token_id is not None) and (token_id.item() is eos_token_id):
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
@@ -1303,8 +1301,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                    continue
                # test that beam scores match previously calculated scores if not eos and batch_idx not done
-               if eos_token_ids is not None and all(
-                   (token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
+               if eos_token_id is not None and all(
+                   (token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
                ):
                    assert torch.all(
                        next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
@@ -1346,7 +1344,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            for i, hypo in enumerate(best):
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
-                   decoded[i, sent_lengths[i]] = eos_token_ids[0]
+                   decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # none of the hypotheses have an eos_token
            assert (len(hypo) == max_length for hypo in best)
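For callers, the visible effect of these changes is that generate() now accepts a single eos_token_id integer (falling back to config.eos_token_id) instead of an eos_token_ids list. A minimal sketch with GPT-2; the "gpt2" checkpoint and the prompt are assumptions for illustration:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("The commit renames eos_token_ids to", return_tensors="pt")
# eos_token_id is now a single int (previously a list under the name eos_token_ids)
output = model.generate(
    input_ids,
    max_length=30,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))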
tests/test_modeling_bart.py

@@ -61,7 +61,7 @@ class ModelTester:
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 20
-       self.eos_token_ids = [2]
+       self.eos_token_id = 2
        self.pad_token_id = 1
        self.bos_token_id = 0
        torch.manual_seed(0)
@@ -82,7 +82,7 @@ class ModelTester:
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
-           eos_token_ids=self.eos_token_ids,
+           eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
        )
@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            output_past=output_past,
-           eos_token_ids=[2],
+           eos_token_id=2,
            pad_token_id=1,
            bos_token_id=0,
        )
@@ -274,7 +274,7 @@ class BartHeadTests(unittest.TestCase):
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            output_past=True,
-           eos_token_ids=[2],
+           eos_token_id=2,
            pad_token_id=1,
            bos_token_id=0,
        )
@@ -483,7 +483,7 @@ class BartModelIntegrationTests(unittest.TestCase):
            no_repeat_ngram_size=3,
            do_sample=False,
            early_stopping=True,
-           decoder_start_token_id=hf.config.eos_token_ids[0],
+           decoder_start_token_id=hf.config.eos_token_id,
        )
        decoded = [
tests/test_modeling_gpt2.py

@@ -132,7 +132,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
            # type_vocab_size=self.type_vocab_size,
            # initializer_range=self.initializer_range
            bos_token_id=self.bos_token_id,
-           eos_token_ids=self.eos_token_id,
+           eos_token_id=self.eos_token_id,
        )

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
tests/test_modeling_tf_gpt2.py

@@ -130,7 +130,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
            # type_vocab_size=self.type_vocab_size,
            # initializer_range=self.initializer_range
            bos_token_id=self.bos_token_id,
-           eos_token_ids=self.eos_token_id,
+           eos_token_id=self.eos_token_id,
        )

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)