Unverified Commit 95e00d08 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Clean special token init in modeling_....py (#3264)

* make style

* fix conflicts
parent 8becb732
......@@ -35,7 +35,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
min_length=min_length + 1, # +1 from original because we start at step=1
no_repeat_ngram_size=3,
early_stopping=True,
decoder_start_token_id=model.config.eos_token_ids[0],
decoder_start_token_id=model.config.eos_token_id,
)
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
for hypothesis in dec:
......
......@@ -223,6 +223,7 @@ if is_torch_available():
BartForSequenceClassification,
BartModel,
BartForConditionalGeneration,
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_roberta import (
RobertaForMaskedLM,
......
......@@ -124,9 +124,12 @@ class AlbertConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
classifier_dropout_prob=0.1,
pad_token_id=0,
bos_token_id=2,
eos_token_id=3,
**kwargs
):
super().__init__(**kwargs)
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.embedding_size = embedding_size
......
......@@ -41,9 +41,6 @@ class BartConfig(PretrainedConfig):
activation_dropout=0.0,
activation_function="gelu",
vocab_size=50265,
bos_token_id=0,
pad_token_id=1,
eos_token_ids=[2],
d_model=1024,
encoder_ffn_dim=4096,
encoder_layers=12,
......@@ -61,6 +58,9 @@ class BartConfig(PretrainedConfig):
output_past=False,
num_labels=3,
is_encoder_decoder=True,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
**common_kwargs
):
r"""
......@@ -74,7 +74,7 @@ class BartConfig(PretrainedConfig):
output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_ids=eos_token_ids,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
**common_kwargs,
)
......
......@@ -124,9 +124,10 @@ class BertConfig(PretrainedConfig):
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
**kwargs
):
super().__init__(**kwargs)
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
......
......@@ -113,9 +113,10 @@ class DistilBertConfig(PretrainedConfig):
initializer_range=0.02,
qa_dropout=0.1,
seq_classif_dropout=0.2,
pad_token_id=0,
**kwargs
):
super().__init__(**kwargs)
super().__init__(**kwargs, pad_token_id=pad_token_id)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.sinusoidal_pos_embds = sinusoidal_pos_embds
......
......@@ -145,9 +145,9 @@ class FlaubertConfig(XLMConfig):
pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "flaubert"
def __init__(self, layerdrop=0.0, pre_norm=False, **kwargs):
def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs):
"""Constructs FlaubertConfig.
"""
super().__init__(**kwargs)
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
self.layerdrop = layerdrop
self.pre_norm = pre_norm
......@@ -142,7 +142,7 @@ class GPT2Config(PretrainedConfig):
eos_token_id=50256,
**kwargs
):
super().__init__(**kwargs)
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.n_ctx = n_ctx
......@@ -163,7 +163,7 @@ class GPT2Config(PretrainedConfig):
self.summary_proj_to_labels = summary_proj_to_labels
self.bos_token_id = bos_token_id
self.eos_token_ids = [eos_token_id]
self.eos_token_id = eos_token_id
@property
def max_position_embeddings(self):
......
......@@ -66,3 +66,8 @@ class RobertaConfig(BertConfig):
"""
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "roberta"
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
"""Constructs FlaubertConfig.
"""
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
......@@ -77,11 +77,11 @@ class T5Config(PretrainedConfig):
initializer_factor=1.0,
is_encoder_decoder=True,
pad_token_id=0,
eos_token_ids=[1],
eos_token_id=1,
**kwargs
):
super().__init__(
is_encoder_decoder=is_encoder_decoder, **kwargs,
pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, **kwargs,
)
self.vocab_size = vocab_size
self.n_positions = n_positions
......
......@@ -152,7 +152,7 @@ class TransfoXLConfig(PretrainedConfig):
eos_token_id=0,
**kwargs
):
super().__init__(**kwargs)
super().__init__(eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.cutoffs = []
......@@ -187,8 +187,6 @@ class TransfoXLConfig(PretrainedConfig):
self.init_std = init_std
self.layer_norm_epsilon = layer_norm_epsilon
self.eos_token_ids = [eos_token_id]
@property
def max_position_embeddings(self):
return self.tgt_len + self.ext_len + self.mem_len
......
......@@ -80,7 +80,7 @@ class PretrainedConfig(object):
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_ids = kwargs.pop("eos_token_ids", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
......
......@@ -194,13 +194,13 @@ class XLMConfig(PretrainedConfig):
end_n_top=5,
mask_token_id=0,
lang_id=0,
bos_token_id=0,
pad_token_id=2,
bos_token_id=0,
**kwargs
):
"""Constructs XLMConfig.
"""
super().__init__(**kwargs)
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
self.vocab_size = vocab_size
self.emb_dim = emb_dim
self.n_layers = n_layers
......@@ -236,9 +236,6 @@ class XLMConfig(PretrainedConfig):
if "n_words" in kwargs:
self.n_words = kwargs["n_words"]
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
@property
def n_words(self): # For backward compatibility
return self.vocab_size
......
......@@ -155,14 +155,14 @@ class XLNetConfig(PretrainedConfig):
summary_last_dropout=0.1,
start_n_top=5,
end_n_top=5,
bos_token_id=1,
pad_token_id=5,
bos_token_id=1,
eos_token_id=2,
**kwargs
):
"""Constructs XLNetConfig.
"""
super().__init__(**kwargs)
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
......@@ -193,7 +193,7 @@ class XLNetConfig(PretrainedConfig):
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.eos_token_ids = [eos_token_id]
self.eos_token_id = eos_token_id
@property
def max_position_embeddings(self):
......
......@@ -906,8 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
def prepare_scores_for_generation(self, scores, cur_len, max_length):
if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id)
if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(scores, self.config.eos_token_id)
return scores
@staticmethod
......@@ -1003,7 +1003,7 @@ class BartForSequenceClassification(PretrainedBartModel):
encoder_outputs=encoder_outputs,
)
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_ids[0])
eos_mask = input_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
......
......@@ -469,7 +469,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
repetition_penalty=None,
bos_token_id=None,
pad_token_id=None,
eos_token_ids=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
......@@ -518,13 +518,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
bos_token_id: (`optional`) int
Beginning of sentence token if no prompt is provided. Default to 0.
Beginning of sentence token if no prompt is provided. Default to specicic model bos_token_id or None if it does not exist.
pad_token_id: (`optional`) int
Pad token. Defaults to pad_token_id as defined in the models config.
eos_token_ids: (`optional`) int or list of int
End of sequence token or list of tokens to stop the generation. Default to 0.
length_penalty: (`optional`) float
Exponential penalty to the length. Default to 1.
......@@ -601,7 +602,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
......@@ -615,8 +616,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
else:
batch_size = 1
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
......@@ -633,9 +632,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a positive integer."
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert (eos_token_id is None) or (
isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_id` should be a positive integer."
assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
......@@ -674,11 +673,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
elif attention_mask is None:
attention_mask = tf.ones_like(input_ids)
if pad_token_id is None and eos_token_ids is not None:
if pad_token_id is None and eos_token_id is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
)
pad_token_id = eos_token_ids[0]
pad_token_id = eos_token_id
# current position and vocab size
cur_len = shape_list(input_ids)[1]
......@@ -742,7 +741,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
......@@ -766,7 +765,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size,
vocab_size=vocab_size,
......@@ -790,7 +789,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
eos_token_id,
decoder_start_token_id,
batch_size,
vocab_size,
......@@ -839,10 +838,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
# create eos_token_ids boolean mask
if eos_token_id is not None and cur_len < min_length:
# create eos_token_id boolean mask
is_token_logit_eos_token = tf.convert_to_tensor(
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
......@@ -865,28 +864,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
# update generations and finished sentences
if eos_token_ids is not None:
# pad finished sentences if eos_token_ids exist
if eos_token_id is not None:
# pad finished sentences if eos_token_id exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
else:
tokens_to_add = next_token
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
if eos_token_ids is not None:
for eos_token_id in eos_token_ids:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
unfinished_sents, tf.cast(eos_in_sents, tf.int32)
)
sent_lengths = (
sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
+ cur_len * is_sents_unfinished_and_token_to_add_is_eos
)
if eos_token_id is not None:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
unfinished_sents, tf.cast(eos_in_sents, tf.int32)
)
sent_lengths = (
sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
+ cur_len * is_sents_unfinished_and_token_to_add_is_eos
)
# unfinished_sents is set to zero if eos in sentence
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
# unfinished_sents is set to zero if eos in sentence
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
# stop when there is a </s> in each sentence, or if we exceed the maximul length
if tf.math.reduce_max(unfinished_sents) == 0:
......@@ -937,8 +935,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
decoder_start_token_id,
eos_token_id,
batch_size,
num_return_sequences,
length_penalty,
......@@ -996,10 +994,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
# create eos_token_ids boolean mask
if eos_token_id is not None and cur_len < min_length:
# create eos_token_id boolean mask
is_token_logit_eos_token = tf.convert_to_tensor(
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
......@@ -1072,7 +1070,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert (
eos_token_ids is not None and pad_token_id is not None
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
......@@ -1091,7 +1089,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence or last iteration
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
if eos_token_id is not None and token_id.numpy() is eos_token_id:
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
......@@ -1148,8 +1146,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if done[batch_idx]:
continue
# test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_ids is not None and all(
(token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx]
if eos_token_id is not None and all(
(token_id % vocab_size).numpy().item() is not eos_token_id for token_id in next_tokens[batch_idx]
):
assert tf.reduce_all(
next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
......@@ -1199,7 +1197,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if sent_lengths[i] < max_length:
decoded_hypo = tf.where(
tf.range(max_length) == sent_lengths[i],
eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
decoded_hypo,
)
decoded_list.append(decoded_hypo)
......
......@@ -665,7 +665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
repetition_penalty=None,
bos_token_id=None,
pad_token_id=None,
eos_token_ids=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
......@@ -713,6 +713,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
pad_token_id: (`optional`) int
Padding token. Default to specicic model pad_token_id or None if it does not exist.
bos_token_id: (`optional`) int
BOS token. Defaults to bos_token_id as defined in the models config.
......@@ -798,7 +801,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
......@@ -812,8 +815,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
batch_size = input_ids.shape[0] # overriden by the input batch_size
else:
batch_size = 1
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
......@@ -830,12 +831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a positive integer."
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
assert (eos_token_id is None) or (
isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_id` should be a positive integer."
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
......@@ -876,13 +877,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
elif attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
# set pad_token_id to eos_token_ids if not set. Important that this is done after
# set pad_token_id to eos_token_id if not set. Important that this is done after
# attention_mask is created
if pad_token_id is None and eos_token_ids is not None:
if pad_token_id is None and eos_token_id is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
)
pad_token_id = eos_token_ids[0]
pad_token_id = eos_token_id
# current position and vocab size
vocab_size = self.config.vocab_size
......@@ -947,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
......@@ -971,8 +972,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
......@@ -994,7 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
eos_token_id,
decoder_start_token_id,
batch_size,
encoder_outputs,
......@@ -1031,9 +1032,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
next_token_logits[:, eos_token_id] = -float("inf")
if eos_token_id is not None and cur_len < min_length:
next_token_logits[:, eos_token_id] = -float("inf")
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
......@@ -1049,22 +1049,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token = torch.argmax(next_token_logits, dim=-1)
# update generations and finished sentences
if eos_token_ids is not None:
# pad finished sentences if eos_token_ids exist
if eos_token_id is not None:
# pad finished sentences if eos_token_id exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
else:
tokens_to_add = next_token
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
if eos_token_ids is not None:
for eos_token_id in eos_token_ids:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
# unfinished_sents is set to zero if eos in sentence
unfinished_sents.mul_((~eos_in_sents).long())
if eos_token_id is not None:
eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
# unfinished_sents is set to zero if eos in sentence
unfinished_sents.mul_((~eos_in_sents).long())
# stop when there is a </s> in each sentence, or if we exceed the maximul length
if unfinished_sents.max() == 0:
......@@ -1106,7 +1105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
eos_token_id,
decoder_start_token_id,
batch_size,
num_return_sequences,
......@@ -1163,9 +1162,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
scores[:, eos_token_id] = -float("inf")
if eos_token_id is not None and cur_len < min_length:
scores[:, eos_token_id] = -float("inf")
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
......@@ -1225,7 +1223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert (
eos_token_ids is not None and pad_token_id is not None
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
......@@ -1244,7 +1242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
if (eos_token_id is not None) and (token_id.item() is eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
......@@ -1303,8 +1301,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
continue
# test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_ids is not None and all(
(token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
if eos_token_id is not None and all(
(token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
):
assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
......@@ -1346,7 +1344,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_ids[0]
decoded[i, sent_lengths[i]] = eos_token_id
else:
# none of the hypotheses have an eos_token
assert (len(hypo) == max_length for hypo in best)
......
......@@ -61,7 +61,7 @@ class ModelTester:
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 20
self.eos_token_ids = [2]
self.eos_token_id = 2
self.pad_token_id = 1
self.bos_token_id = 0
torch.manual_seed(0)
......@@ -82,7 +82,7 @@ class ModelTester:
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_ids=self.eos_token_ids,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
)
......@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32,
max_position_embeddings=48,
output_past=output_past,
eos_token_ids=[2],
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
)
......@@ -274,7 +274,7 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32,
max_position_embeddings=48,
output_past=True,
eos_token_ids=[2],
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
)
......@@ -483,7 +483,7 @@ class BartModelIntegrationTests(unittest.TestCase):
no_repeat_ngram_size=3,
do_sample=False,
early_stopping=True,
decoder_start_token_id=hf.config.eos_token_ids[0],
decoder_start_token_id=hf.config.eos_token_id,
)
decoded = [
......
......@@ -132,7 +132,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
bos_token_id=self.bos_token_id,
eos_token_ids=self.eos_token_id,
eos_token_id=self.eos_token_id,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......
......@@ -130,7 +130,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
bos_token_id=self.bos_token_id,
eos_token_ids=self.eos_token_id,
eos_token_id=self.eos_token_id,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment