Unverified commit 95e00d08, authored by Patrick von Platen and committed by GitHub

Clean special token init in modeling_....py (#3264)

* make style

* fix conflicts
parent 8becb732
@@ -35,7 +35,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
         min_length=min_length + 1,  # +1 from original because we start at step=1
         no_repeat_ngram_size=3,
         early_stopping=True,
-        decoder_start_token_id=model.config.eos_token_ids[0],
+        decoder_start_token_id=model.config.eos_token_id,
     )
     dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
     for hypothesis in dec:
......
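Note (illustrative, not part of the diff): with the rename from the list-valued `eos_token_ids` to the scalar `eos_token_id`, the summarization script above reads the token id straight off the config. A minimal sketch of that call pattern, assuming the `facebook/bart-large-cnn` checkpoint and a made-up input text:

    # Sketch only: checkpoint name and input text are assumptions, not part of this commit.
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

    batch = tokenizer.batch_encode_plus(["A long article to summarize."], return_tensors="pt")
    summaries = model.generate(
        batch["input_ids"],
        num_beams=4,
        max_length=60,
        early_stopping=True,
        decoder_start_token_id=model.config.eos_token_id,  # scalar now, no eos_token_ids[0] indexing
    )
    print(tokenizer.decode(summaries[0], skip_special_tokens=True))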
@@ -223,6 +223,7 @@ if is_torch_available():
         BartForSequenceClassification,
         BartModel,
         BartForConditionalGeneration,
+        BART_PRETRAINED_MODEL_ARCHIVE_MAP,
     )
     from .modeling_roberta import (
         RobertaForMaskedLM,
......
@@ -124,9 +124,12 @@ class AlbertConfig(PretrainedConfig):
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         classifier_dropout_prob=0.1,
+        pad_token_id=0,
+        bos_token_id=2,
+        eos_token_id=3,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.embedding_size = embedding_size
......
@@ -41,9 +41,6 @@ class BartConfig(PretrainedConfig):
         activation_dropout=0.0,
         activation_function="gelu",
         vocab_size=50265,
-        bos_token_id=0,
-        pad_token_id=1,
-        eos_token_ids=[2],
         d_model=1024,
         encoder_ffn_dim=4096,
         encoder_layers=12,
@@ -61,6 +58,9 @@ class BartConfig(PretrainedConfig):
         output_past=False,
         num_labels=3,
         is_encoder_decoder=True,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
         **common_kwargs
     ):
         r"""
@@ -74,7 +74,7 @@ class BartConfig(PretrainedConfig):
             output_past=output_past,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
-            eos_token_ids=eos_token_ids,
+            eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             **common_kwargs,
         )
......
@@ -124,9 +124,10 @@ class BertConfig(PretrainedConfig):
         type_vocab_size=2,
         initializer_range=0.02,
         layer_norm_eps=1e-12,
+        pad_token_id=0,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
......
@@ -113,9 +113,10 @@ class DistilBertConfig(PretrainedConfig):
         initializer_range=0.02,
         qa_dropout=0.1,
         seq_classif_dropout=0.2,
+        pad_token_id=0,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(**kwargs, pad_token_id=pad_token_id)
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.sinusoidal_pos_embds = sinusoidal_pos_embds
......
@@ -145,9 +145,9 @@ class FlaubertConfig(XLMConfig):
     pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "flaubert"
-    def __init__(self, layerdrop=0.0, pre_norm=False, **kwargs):
+    def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs):
         """Constructs FlaubertConfig.
         """
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
         self.layerdrop = layerdrop
         self.pre_norm = pre_norm
@@ -142,7 +142,7 @@ class GPT2Config(PretrainedConfig):
         eos_token_id=50256,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.n_ctx = n_ctx
@@ -163,7 +163,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_proj_to_labels = summary_proj_to_labels
         self.bos_token_id = bos_token_id
-        self.eos_token_ids = [eos_token_id]
+        self.eos_token_id = eos_token_id
     @property
     def max_position_embeddings(self):
......
@@ -66,3 +66,8 @@ class RobertaConfig(BertConfig):
     """
     pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "roberta"
+
+    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
+        """Constructs RobertaConfig.
+        """
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -77,11 +77,11 @@ class T5Config(PretrainedConfig):
         initializer_factor=1.0,
         is_encoder_decoder=True,
         pad_token_id=0,
-        eos_token_ids=[1],
+        eos_token_id=1,
         **kwargs
     ):
         super().__init__(
-            is_encoder_decoder=is_encoder_decoder, **kwargs,
+            pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, **kwargs,
         )
         self.vocab_size = vocab_size
         self.n_positions = n_positions
......
@@ -152,7 +152,7 @@ class TransfoXLConfig(PretrainedConfig):
         eos_token_id=0,
         **kwargs
     ):
-        super().__init__(**kwargs)
+        super().__init__(eos_token_id=eos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.cutoffs = []
@@ -187,8 +187,6 @@ class TransfoXLConfig(PretrainedConfig):
         self.init_std = init_std
         self.layer_norm_epsilon = layer_norm_epsilon
-        self.eos_token_ids = [eos_token_id]
     @property
     def max_position_embeddings(self):
         return self.tgt_len + self.ext_len + self.mem_len
......
@@ -80,7 +80,7 @@ class PretrainedConfig(object):
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
         self.bos_token_id = kwargs.pop("bos_token_id", None)
         self.pad_token_id = kwargs.pop("pad_token_id", None)
-        self.eos_token_ids = kwargs.pop("eos_token_ids", None)
+        self.eos_token_id = kwargs.pop("eos_token_id", None)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
......
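The `PretrainedConfig` change above is what lets every model config forward its special token ids through `super().__init__(...)` and read them back as scalar attributes. A minimal sketch of the resulting pattern, using a hypothetical `MyConfig` subclass that is not part of the diff:

    from transformers import PretrainedConfig

    class MyConfig(PretrainedConfig):
        # Hypothetical subclass, only to illustrate the post-commit pattern.
        model_type = "my-model"

        def __init__(self, vocab_size=1000, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
            super().__init__(
                pad_token_id=pad_token_id,
                bos_token_id=bos_token_id,
                eos_token_id=eos_token_id,
                **kwargs,
            )
            self.vocab_size = vocab_size

    config = MyConfig()
    assert config.eos_token_id == 2  # scalar attribute, replacing the old eos_token_ids list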
@@ -194,13 +194,13 @@ class XLMConfig(PretrainedConfig):
         end_n_top=5,
         mask_token_id=0,
         lang_id=0,
-        bos_token_id=0,
         pad_token_id=2,
+        bos_token_id=0,
         **kwargs
     ):
         """Constructs XLMConfig.
         """
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.emb_dim = emb_dim
         self.n_layers = n_layers
@@ -236,9 +236,6 @@ class XLMConfig(PretrainedConfig):
         if "n_words" in kwargs:
             self.n_words = kwargs["n_words"]
-        self.bos_token_id = bos_token_id
-        self.pad_token_id = pad_token_id
     @property
     def n_words(self):  # For backward compatibility
         return self.vocab_size
......
@@ -155,14 +155,14 @@ class XLNetConfig(PretrainedConfig):
         summary_last_dropout=0.1,
         start_n_top=5,
         end_n_top=5,
-        bos_token_id=1,
         pad_token_id=5,
+        bos_token_id=1,
         eos_token_id=2,
         **kwargs
     ):
         """Constructs XLNetConfig.
         """
-        super().__init__(**kwargs)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
         self.vocab_size = vocab_size
         self.d_model = d_model
         self.n_layer = n_layer
@@ -193,7 +193,7 @@ class XLNetConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.pad_token_id = pad_token_id
-        self.eos_token_ids = [eos_token_id]
+        self.eos_token_id = eos_token_id
     @property
     def max_position_embeddings(self):
......
@@ -906,8 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def prepare_scores_for_generation(self, scores, cur_len, max_length):
         if cur_len == 1:
             self._force_token_ids_generation(scores, self.config.bos_token_id)
-        if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
-            self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
+        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
+            self._force_token_ids_generation(scores, self.config.eos_token_id)
         return scores
     @staticmethod
@@ -1003,7 +1003,7 @@ class BartForSequenceClassification(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
         )
         x = outputs[0]  # last hidden state
-        eos_mask = input_ids.eq(self.config.eos_token_ids[0])
+        eos_mask = input_ids.eq(self.config.eos_token_id)
         if len(torch.unique(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
......
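For context, a self-contained sketch of the `eos_mask` selection used by the classification head above, now keyed on the scalar `config.eos_token_id` (tensor values below are made up for illustration):

    import torch

    eos_token_id = 2
    input_ids = torch.tensor([[0, 5, 6, 2], [0, 7, 2, 1]])  # fake ids; 2 marks </s>, once per row
    x = torch.randn(2, 4, 8)                                # stand-in for the last hidden state

    eos_mask = input_ids.eq(eos_token_id)                   # single comparison instead of eos_token_ids[0]
    sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
    print(sentence_representation.shape)                    # torch.Size([2, 8])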
@@ -469,7 +469,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         repetition_penalty=None,
         bos_token_id=None,
         pad_token_id=None,
-        eos_token_ids=None,
+        eos_token_id=None,
         length_penalty=None,
         no_repeat_ngram_size=None,
         num_return_sequences=None,
@@ -518,13 +518,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
             bos_token_id: (`optional`) int
-                Beginning of sentence token if no prompt is provided. Default to 0.
+                Beginning of sentence token if no prompt is provided. Default to the specific model bos_token_id, or None if it does not exist.
             pad_token_id: (`optional`) int
                 Pad token. Defaults to pad_token_id as defined in the models config.
             eos_token_ids: (`optional`) int or list of int
                 End of sequence token or list of tokens to stop the generation. Default to 0.
             length_penalty: (`optional`) float
                 Exponential penalty to the length. Default to 1.
@@ -601,7 +602,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
         no_repeat_ngram_size = (
             no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
@@ -615,8 +616,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             batch_size = shape_list(input_ids)[0]  # overridden by the input batch_size
         else:
             batch_size = 1
-        if isinstance(eos_token_ids, int):
-            eos_token_ids = [eos_token_ids]
         assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
         assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
@@ -633,9 +632,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         assert pad_token_id is None or (
             isinstance(pad_token_id, int) and (pad_token_id >= 0)
         ), "`pad_token_id` should be a positive integer."
-        assert (eos_token_ids is None) or (
-            isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
-        ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
+        assert (eos_token_id is None) or (
+            isinstance(eos_token_id, int) and (eos_token_id >= 0)
+        ), "`eos_token_id` should be a positive integer."
         assert (
             decoder_start_token_id is not None or self.config.is_encoder_decoder is False
         ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
@@ -674,11 +673,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         elif attention_mask is None:
             attention_mask = tf.ones_like(input_ids)
-        if pad_token_id is None and eos_token_ids is not None:
+        if pad_token_id is None and eos_token_id is not None:
             logger.warning(
-                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
+                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
             )
-            pad_token_id = eos_token_ids[0]
+            pad_token_id = eos_token_id
         # current position and vocab size
         cur_len = shape_list(input_ids)[1]
@@ -742,7 +741,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 no_repeat_ngram_size=no_repeat_ngram_size,
                 bos_token_id=bos_token_id,
                 pad_token_id=pad_token_id,
-                eos_token_ids=eos_token_ids,
+                eos_token_id=eos_token_id,
                 decoder_start_token_id=decoder_start_token_id,
                 batch_size=effective_batch_size,
                 num_return_sequences=num_return_sequences,
@@ -766,7 +765,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 no_repeat_ngram_size=no_repeat_ngram_size,
                 bos_token_id=bos_token_id,
                 pad_token_id=pad_token_id,
-                eos_token_ids=eos_token_ids,
+                eos_token_id=eos_token_id,
                 decoder_start_token_id=decoder_start_token_id,
                 batch_size=effective_batch_size,
                 vocab_size=vocab_size,
@@ -790,7 +789,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         no_repeat_ngram_size,
         bos_token_id,
         pad_token_id,
-        eos_token_ids,
+        eos_token_id,
         decoder_start_token_id,
         batch_size,
         vocab_size,
@@ -839,10 +838,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             )
             # set eos token prob to zero if min_length is not reached
-            if eos_token_ids is not None and cur_len < min_length:
-                # create eos_token_ids boolean mask
+            if eos_token_id is not None and cur_len < min_length:
+                # create eos_token_id boolean mask
                 is_token_logit_eos_token = tf.convert_to_tensor(
-                    [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+                    [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
                 )
                 eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
@@ -865,28 +864,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
             # update generations and finished sentences
-            if eos_token_ids is not None:
-                # pad finished sentences if eos_token_ids exist
+            if eos_token_id is not None:
+                # pad finished sentences if eos_token_id exists
                 tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
             else:
                 tokens_to_add = next_token
             input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
-            if eos_token_ids is not None:
-                for eos_token_id in eos_token_ids:
-                    eos_in_sents = tokens_to_add == eos_token_id
-                    # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
-                    is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
-                        unfinished_sents, tf.cast(eos_in_sents, tf.int32)
-                    )
-                    sent_lengths = (
-                        sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
-                        + cur_len * is_sents_unfinished_and_token_to_add_is_eos
-                    )
-                    # unfinished_sents is set to zero if eos in sentence
-                    unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
+            if eos_token_id is not None:
+                eos_in_sents = tokens_to_add == eos_token_id
+                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
+                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
+                    unfinished_sents, tf.cast(eos_in_sents, tf.int32)
+                )
+                sent_lengths = (
+                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
+                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
+                )
+                # unfinished_sents is set to zero if eos in sentence
+                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
             # stop when there is a </s> in each sentence, or if we exceed the maximum length
             if tf.math.reduce_max(unfinished_sents) == 0:
@@ -937,8 +935,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         no_repeat_ngram_size,
         bos_token_id,
         pad_token_id,
-        eos_token_ids,
         decoder_start_token_id,
+        eos_token_id,
         batch_size,
         num_return_sequences,
         length_penalty,
@@ -996,10 +994,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
             # set eos token prob to zero if min_length is not reached
-            if eos_token_ids is not None and cur_len < min_length:
-                # create eos_token_ids boolean mask
+            if eos_token_id is not None and cur_len < min_length:
+                # create eos_token_id boolean mask
                 is_token_logit_eos_token = tf.convert_to_tensor(
-                    [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
+                    [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
                 )
                 eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
@@ -1072,7 +1070,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     len(generated_hyps[batch_idx]) >= num_beams
                 ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                 assert (
-                    eos_token_ids is not None and pad_token_id is not None
+                    eos_token_id is not None and pad_token_id is not None
                 ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                 next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                 continue
@@ -1091,7 +1089,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     effective_beam_id = batch_idx * num_beams + beam_id
                     # add to generated hypotheses if end of sentence or last iteration
-                    if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
+                    if eos_token_id is not None and token_id.numpy() == eos_token_id:
                         # if beam_token does not belong to top num_beams tokens, it should not be added
                         is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                         if is_beam_token_worse_than_top_num_beams:
@@ -1148,8 +1146,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 if done[batch_idx]:
                     continue
                 # test that beam scores match previously calculated scores if not eos and batch_idx not done
-                if eos_token_ids is not None and all(
-                    (token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx]
+                if eos_token_id is not None and all(
+                    (token_id % vocab_size).numpy().item() != eos_token_id for token_id in next_tokens[batch_idx]
                 ):
                     assert tf.reduce_all(
                         next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
@@ -1199,7 +1197,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             if sent_lengths[i] < max_length:
                 decoded_hypo = tf.where(
                     tf.range(max_length) == sent_lengths[i],
-                    eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
+                    eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
                     decoded_hypo,
                 )
             decoded_list.append(decoded_hypo)
......
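A short, hedged sketch of the eos-logit masking that the TF hunks above perform for a single `eos_token_id` (batch size, vocabulary size, and token id are made up; `tf.fill`/`tf.where` stand in for the helper used in the file):

    import tensorflow as tf

    batch_size, vocab_size, eos_token_id = 2, 10, 3
    next_token_logits = tf.random.normal((batch_size, vocab_size))

    # True only at the eos column, mirroring the boolean mask built in the hunks above.
    is_token_logit_eos_token = tf.convert_to_tensor(
        [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
    )
    eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
    masked_logits = tf.where(
        eos_token_indices_mask, tf.fill([batch_size, vocab_size], -float("inf")), next_token_logits
    )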
@@ -665,7 +665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         repetition_penalty=None,
         bos_token_id=None,
         pad_token_id=None,
-        eos_token_ids=None,
+        eos_token_id=None,
         length_penalty=None,
         no_repeat_ngram_size=None,
         num_return_sequences=None,
@@ -713,6 +713,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             repetition_penalty: (`optional`) float
                 The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
+            pad_token_id: (`optional`) int
+                Padding token. Default to the specific model pad_token_id, or None if it does not exist.
             bos_token_id: (`optional`) int
                 BOS token. Defaults to bos_token_id as defined in the models config.
@@ -798,7 +801,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
+        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
         no_repeat_ngram_size = (
             no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
@@ -812,8 +815,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             batch_size = input_ids.shape[0]  # overridden by the input batch_size
         else:
             batch_size = 1
-        if isinstance(eos_token_ids, int):
-            eos_token_ids = [eos_token_ids]
         assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
         assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
@@ -830,12 +831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert pad_token_id is None or (
             isinstance(pad_token_id, int) and (pad_token_id >= 0)
         ), "`pad_token_id` should be a positive integer."
-        assert (eos_token_ids is None) or (
-            isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
-        ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
         assert (
             decoder_start_token_id is not None or self.config.is_encoder_decoder is False
         ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
+        assert (eos_token_id is None) or (
+            isinstance(eos_token_id, int) and (eos_token_id >= 0)
+        ), "`eos_token_id` should be a positive integer."
         assert length_penalty > 0, "`length_penalty` should be strictly positive."
         assert (
             isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
@@ -876,13 +877,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         elif attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
-        # set pad_token_id to eos_token_ids if not set. Important that this is done after
+        # set pad_token_id to eos_token_id if not set. Important that this is done after
         # attention_mask is created
-        if pad_token_id is None and eos_token_ids is not None:
+        if pad_token_id is None and eos_token_id is not None:
             logger.warning(
-                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
+                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
             )
-            pad_token_id = eos_token_ids[0]
+            pad_token_id = eos_token_id
         # current position and vocab size
         vocab_size = self.config.vocab_size
@@ -947,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 no_repeat_ngram_size=no_repeat_ngram_size,
                 bos_token_id=bos_token_id,
                 pad_token_id=pad_token_id,
-                eos_token_ids=eos_token_ids,
                 decoder_start_token_id=decoder_start_token_id,
+                eos_token_id=eos_token_id,
                 batch_size=effective_batch_size,
                 num_return_sequences=num_return_sequences,
                 length_penalty=length_penalty,
@@ -971,8 +972,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 no_repeat_ngram_size=no_repeat_ngram_size,
                 bos_token_id=bos_token_id,
                 pad_token_id=pad_token_id,
-                eos_token_ids=eos_token_ids,
                 decoder_start_token_id=decoder_start_token_id,
+                eos_token_id=eos_token_id,
                 batch_size=effective_batch_size,
                 encoder_outputs=encoder_outputs,
                 attention_mask=attention_mask,
@@ -994,7 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         no_repeat_ngram_size,
         bos_token_id,
         pad_token_id,
-        eos_token_ids,
+        eos_token_id,
         decoder_start_token_id,
         batch_size,
         encoder_outputs,
@@ -1031,9 +1032,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
             # set eos token prob to zero if min_length is not reached
-            if eos_token_ids is not None and cur_len < min_length:
-                for eos_token_id in eos_token_ids:
-                    next_token_logits[:, eos_token_id] = -float("inf")
+            if eos_token_id is not None and cur_len < min_length:
+                next_token_logits[:, eos_token_id] = -float("inf")
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
@@ -1049,22 +1049,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token = torch.argmax(next_token_logits, dim=-1)
             # update generations and finished sentences
-            if eos_token_ids is not None:
-                # pad finished sentences if eos_token_ids exist
+            if eos_token_id is not None:
+                # pad finished sentences if eos_token_id exists
                 tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
             else:
                 tokens_to_add = next_token
             input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
-            if eos_token_ids is not None:
-                for eos_token_id in eos_token_ids:
-                    eos_in_sents = tokens_to_add == eos_token_id
-                    # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
-                    is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
-                    sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
-                    # unfinished_sents is set to zero if eos in sentence
-                    unfinished_sents.mul_((~eos_in_sents).long())
+            if eos_token_id is not None:
+                eos_in_sents = tokens_to_add == eos_token_id
+                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
+                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
+                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
+                # unfinished_sents is set to zero if eos in sentence
+                unfinished_sents.mul_((~eos_in_sents).long())
             # stop when there is a </s> in each sentence, or if we exceed the maximum length
             if unfinished_sents.max() == 0:
@@ -1106,7 +1105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         no_repeat_ngram_size,
         bos_token_id,
         pad_token_id,
-        eos_token_ids,
+        eos_token_id,
         decoder_start_token_id,
         batch_size,
         num_return_sequences,
@@ -1163,9 +1162,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
             # set eos token prob to zero if min_length is not reached
-            if eos_token_ids is not None and cur_len < min_length:
-                for eos_token_id in eos_token_ids:
-                    scores[:, eos_token_id] = -float("inf")
+            if eos_token_id is not None and cur_len < min_length:
+                scores[:, eos_token_id] = -float("inf")
             if no_repeat_ngram_size > 0:
                 # calculate a list of banned tokens to prevent repetitively generating the same ngrams
@@ -1225,7 +1223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     len(generated_hyps[batch_idx]) >= num_beams
                 ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                 assert (
-                    eos_token_ids is not None and pad_token_id is not None
+                    eos_token_id is not None and pad_token_id is not None
                 ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                 next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                 continue
@@ -1244,7 +1242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 effective_beam_id = batch_idx * num_beams + beam_id
                 # add to generated hypotheses if end of sentence
-                if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
+                if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                     # if beam_token does not belong to top num_beams tokens, it should not be added
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                     if is_beam_token_worse_than_top_num_beams:
@@ -1303,8 +1301,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 continue
             # test that beam scores match previously calculated scores if not eos and batch_idx not done
-            if eos_token_ids is not None and all(
-                (token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
+            if eos_token_id is not None and all(
+                (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
             ):
                 assert torch.all(
                     next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
@@ -1346,7 +1344,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            for i, hypo in enumerate(best):
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
-                    decoded[i, sent_lengths[i]] = eos_token_ids[0]
+                    decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # none of the hypotheses have an eos_token
            assert (len(hypo) == max_length for hypo in best)
......
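The PyTorch hunks above all follow from the same simplification: with a single `eos_token_id` there is no inner `for eos_token_id in eos_token_ids:` loop, so the logit masking and the finished-sentence bookkeeping operate on one id directly. A minimal, self-contained sketch of that logic with made-up shapes and values:

    import torch

    batch_size, vocab_size = 2, 10
    eos_token_id, cur_len, min_length = 3, 1, 5

    next_token_logits = torch.randn(batch_size, vocab_size)
    if eos_token_id is not None and cur_len < min_length:
        next_token_logits[:, eos_token_id] = -float("inf")  # forbid eos before min_length

    unfinished_sents = torch.ones(batch_size, dtype=torch.long)
    tokens_to_add = torch.argmax(next_token_logits, dim=-1)
    eos_in_sents = tokens_to_add == eos_token_id             # single comparison, no loop over a list
    unfinished_sents.mul_((~eos_in_sents).long())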
@@ -61,7 +61,7 @@ class ModelTester:
         self.hidden_dropout_prob = 0.1
         self.attention_probs_dropout_prob = 0.1
         self.max_position_embeddings = 20
-        self.eos_token_ids = [2]
+        self.eos_token_id = 2
         self.pad_token_id = 1
         self.bos_token_id = 0
         torch.manual_seed(0)
@@ -82,7 +82,7 @@ class ModelTester:
             dropout=self.hidden_dropout_prob,
             attention_dropout=self.attention_probs_dropout_prob,
             max_position_embeddings=self.max_position_embeddings,
-            eos_token_ids=self.eos_token_ids,
+            eos_token_id=self.eos_token_id,
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
             output_past=output_past,
-            eos_token_ids=[2],
+            eos_token_id=2,
             pad_token_id=1,
             bos_token_id=0,
         )
@@ -274,7 +274,7 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
             output_past=True,
-            eos_token_ids=[2],
+            eos_token_id=2,
             pad_token_id=1,
             bos_token_id=0,
         )
@@ -483,7 +483,7 @@ class BartModelIntegrationTests(unittest.TestCase):
             no_repeat_ngram_size=3,
             do_sample=False,
             early_stopping=True,
-            decoder_start_token_id=hf.config.eos_token_ids[0],
+            decoder_start_token_id=hf.config.eos_token_id,
         )
         decoded = [
......
@@ -132,7 +132,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
             # type_vocab_size=self.type_vocab_size,
             # initializer_range=self.initializer_range
             bos_token_id=self.bos_token_id,
-            eos_token_ids=self.eos_token_id,
+            eos_token_id=self.eos_token_id,
         )
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......
@@ -130,7 +130,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
             # type_vocab_size=self.type_vocab_size,
             # initializer_range=self.initializer_range
             bos_token_id=self.bos_token_id,
-            eos_token_ids=self.eos_token_id,
+            eos_token_id=self.eos_token_id,
         )
         head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......