Unverified Commit c130e67d authored by Suraj Patil, committed by GitHub

remove adjust_logits_during_generation method (#10087)

* add forced logits processors

* delete adjust_logits method

* add forced_eos_token_id argument in config

* add tests for forced logits processors

* update gen utils tests

* add forced option to tf generate

* remove adjust_logits method from tf models

* update adjust_logits for marian

* delete _force_token_id_to_be_generated method

* style

* import warnings

* pass max_length to _get_logits_processor

* set forced_eos_token_id to None

* set forced attributes in conf utils

* typo

* fix rag generate

* add forced_eos_token_id in rag config

* remove force_bos_token_to_be_generated from BartConfig

* remove _force_token_ids_generation from FSMT

* nit

* fix negative constant

* apply suggestions from code review
parent 22a32cf4
...@@ -131,6 +131,11 @@ class PretrainedConfig(object):
logits when used for generation
- **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should
return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor`
- **forced_bos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the first generated token
after the :obj:`decoder_start_token_id`. Useful for multilingual models like :doc:`mBART
<../model_doc/mbart>` where the first generated token needs to be the target language token.
- **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token
when :obj:`max_length` is reached.
Parameters for fine-tuning tasks
...@@ -214,6 +219,8 @@ class PretrainedConfig(object):
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
self.output_scores = kwargs.pop("output_scores", False)
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
......
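For reference, a minimal sketch (not part of this diff) of how the two new config attributes are meant to be used; the values below are illustrative:

    from transformers import BartConfig

    # The forced token ids now live on the config instead of being handled by a
    # per-model `adjust_logits_during_generation` override.
    config = BartConfig(forced_bos_token_id=0, forced_eos_token_id=2)
    print(config.forced_bos_token_id, config.forced_eos_token_id)  # 0 2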
...@@ -520,3 +520,49 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
return scores
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
r"""
:class:`~transformers.LogitsProcessor` that enforces the specified token as the first generated token.
Args:
bos_token_id (:obj:`int`):
The id of the token to force as the first generated token.
"""
def __init__(self, bos_token_id: int):
self.bos_token_id = bos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == 1:
num_tokens = scores.shape[1]
scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf")
scores[:, self.bos_token_id] = 0
return scores
class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
r"""
:class:`~transformers.LogitsProcessor` that enforces the specified token as the last generated token when
:obj:`max_length` is reached.
Args:
max_length (:obj:`int`):
The maximum length of the sequence to be generated.
eos_token_id (:obj:`int`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
"""
def __init__(self, max_length: int, eos_token_id: int):
self.max_length = max_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == self.max_length - 1:
num_tokens = scores.shape[1]
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
scores[:, self.eos_token_id] = 0
return scores
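A quick sketch of how the two new processors behave (toy tensors and arbitrary token ids; not part of this diff):

    import torch
    from transformers import LogitsProcessorList
    from transformers.generation_logits_process import (
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
    )

    # Batch of 2, toy vocabulary of 5 tokens; only the decoder start token generated so far.
    input_ids = torch.tensor([[2], [2]])
    scores = torch.zeros(2, 5)

    processors = LogitsProcessorList(
        [
            ForcedBOSTokenLogitsProcessor(bos_token_id=0),
            ForcedEOSTokenLogitsProcessor(max_length=4, eos_token_id=1),
        ]
    )
    scores = processors(input_ids, scores)
    print(scores[0])  # every entry except index 0 is -inf, so token 0 is forced as the first token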
...@@ -67,6 +67,8 @@ class TFGenerationMixin:
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
...@@ -137,6 +139,12 @@ class TFGenerationMixin:
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
forced_bos_token_id (:obj:`int`, `optional`):
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
needs to be the target language token.
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
...@@ -214,6 +222,12 @@ class TFGenerationMixin:
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
forced_bos_token_id = (
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
)
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
if input_ids is not None:
batch_size = shape_list(input_ids)[0]  # overridden by the input batch_size
...@@ -380,6 +394,8 @@ class TFGenerationMixin:
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
)
else:
output = self._generate_no_beam_search(
...@@ -591,6 +607,8 @@ class TFGenerationMixin:
encoder_outputs,
attention_mask,
use_cache,
forced_bos_token_id,
forced_eos_token_id,
):
"""Generate sequences for each example with beam search."""
...@@ -641,7 +659,11 @@ class TFGenerationMixin:
if self.config.is_encoder_decoder and do_sample is False:
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
next_token_logits,
cur_len=cur_len,
max_length=max_length,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
)
# calculate log softmax score
scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
...@@ -893,12 +915,21 @@ class TFGenerationMixin:
def _reorder_cache(past, beam_idx):
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
def adjust_logits_during_generation(self, logits, **kwargs):
    """
    Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
    the generate method.
    """
    return logits
def adjust_logits_during_generation(
    self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
):
    """
    Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
    the generate method.
    """
    if cur_len == 1 and forced_bos_token_id is not None:
        vocab_range = tf.constant(range(self.config.vocab_size))
        return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
    elif cur_len == max_length - 1 and forced_eos_token_id is not None:
        vocab_range = tf.constant(range(self.config.vocab_size))
        return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
    else:
        return logits
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
......
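The TF path uses the same masking trick as the PyTorch processors, just expressed with `tf.where` and a large negative constant instead of `-inf`. A standalone illustration (toy vocab size and token id, not part of this diff):

    import tensorflow as tf

    vocab_size, forced_token_id = 6, 3
    logits = tf.zeros((1, vocab_size))
    vocab_range = tf.constant(range(vocab_size))
    # Everything except the forced id is pushed to -1e8, so argmax/log_softmax can only pick it.
    masked = tf.where(vocab_range != forced_token_id, -1e8, logits)
    print(masked.numpy())  # token 3 keeps its logit, all other positions are ~ -1e8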
...@@ -24,6 +24,8 @@ from .file_utils import ModelOutput
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
...@@ -542,7 +544,10 @@ class GenerationMixin:
encoder_input_ids: torch.LongTensor,
bad_words_ids: List[List[int]],
min_length: int,
max_length: int,
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
num_beams: int,
num_beam_groups: int,
...@@ -567,6 +572,12 @@ class GenerationMixin:
min_length = min_length if min_length is not None else self.config.min_length
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
forced_bos_token_id = (
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
)
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
# instantiate processors list
processors = LogitsProcessorList()
...@@ -595,6 +606,10 @@ class GenerationMixin:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
if forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
return processors
@torch.no_grad()
...@@ -627,6 +642,8 @@ class GenerationMixin:
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
r"""
...@@ -720,6 +737,12 @@ class GenerationMixin:
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
forced_bos_token_id (:obj:`int`, `optional`):
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
needs to be the target language token.
forced_eos_token_id (:obj:`int`, `optional`):
The id of the token to force as the last generated token when :obj:`max_length` is reached.
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
...@@ -888,7 +911,10 @@ class GenerationMixin:
encoder_input_ids=encoder_input_ids,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
...@@ -1611,7 +1637,8 @@ class GenerationMixin:
)
next_token_logits = outputs.logits[:, -1, :]
# adjust tokens for Bart, *e.g.*
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
...@@ -1866,7 +1893,8 @@ class GenerationMixin:
)
next_token_logits = outputs.logits[:, -1, :]
# adjust token scores (a no-op by default)
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
...@@ -2150,7 +2178,8 @@ class GenerationMixin:
# select outputs of beams of current group only
next_token_logits = outputs.logits[batch_group_indices, -1, :]
# adjust tokens for Bart, *e.g.*
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `F.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
......
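With `forced_bos_token_id` exposed on `generate()`, the mBART use case no longer needs a model-specific logits hack. A hypothetical usage sketch (the checkpoint name and language code are illustrative):

    from transformers import MBartForConditionalGeneration, MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

    inputs = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")
    generated = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"],  # force the target-language token first
        max_length=40,
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))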
...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" BART model configuration """
import warnings
from ...configuration_utils import PretrainedConfig
from ...utils import logging
...@@ -72,9 +73,6 @@ class BartConfig(PretrainedConfig):
just in case (e.g., 512 or 1024 or 2048).
init_std (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only
:obj:`True` for `bart-large-cnn`.
encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the encoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
...@@ -89,6 +87,9 @@ class BartConfig(PretrainedConfig):
Whether or not the model should return the last key/values attentions (not used by all models).
num_labels: (:obj:`int`, `optional`, defaults to 3):
The number of labels to use in :class:`~transformers.BartForSequenceClassification`.
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
Example::
...@@ -127,7 +128,6 @@ class BartConfig(PretrainedConfig):
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
force_bos_token_to_be_generated=False,
use_cache=True,
num_labels=3,
pad_token_id=1,
...@@ -135,6 +135,7 @@ class BartConfig(PretrainedConfig):
eos_token_id=2,
is_encoder_decoder=True,
decoder_start_token_id=2,
forced_eos_token_id=2,
**kwargs
):
super().__init__(
...@@ -144,6 +145,7 @@ class BartConfig(PretrainedConfig):
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
...@@ -168,7 +170,14 @@ class BartConfig(PretrainedConfig):
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN
# ensure backward compatibility for BART CNN models
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
"The config can simply be saved and uploaded again to be fixed."
)
@property
def num_attention_heads(self) -> int:
......
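A small sketch of the backward-compatibility path above: an old-style config that still passes `force_bos_token_to_be_generated=True` ends up with `forced_bos_token_id` set to its `bos_token_id` and triggers the warning. Values are illustrative, not part of this diff:

    from transformers import BartConfig

    old_style = BartConfig(force_bos_token_to_be_generated=True)  # old kwarg, still accepted via **kwargs
    print(old_style.forced_bos_token_id == old_style.bos_token_id)  # True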
...@@ -1344,18 +1344,6 @@ class BartForConditionalGeneration(BartPretrainedModel):
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
@staticmethod
def _force_token_id_to_be_generated(scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
......
...@@ -1444,13 +1444,3 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
+ layer_past_key_values[2:],
)
return (past[0], reordered_past)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
vocab_range = tf.constant(range(self.config.vocab_size))
return tf.where(vocab_range != self.config.bos_token_id, LARGE_NEGATIVE, logits)
elif cur_len == max_length - 1:
vocab_range = tf.constant(range(self.config.vocab_size))
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
else:
return logits
...@@ -84,6 +84,9 @@ class BlenderbotConfig(PretrainedConfig):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models)
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
Example::
...@@ -129,6 +132,7 @@ class BlenderbotConfig(PretrainedConfig):
bos_token_id=1,
eos_token_id=2,
encoder_no_repeat_ngram_size=3,
forced_eos_token_id=2,
**kwargs
):
super().__init__(
...@@ -138,6 +142,7 @@ class BlenderbotConfig(PretrainedConfig):
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
......
...@@ -1335,16 +1335,6 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
"use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
}
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
@staticmethod
def _force_token_id_to_be_generated(scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
......
...@@ -1477,10 +1477,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
+ layer_past_key_values[2:],
)
return (past[0], reordered_past)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1:
vocab_range = tf.constant(range(self.config.vocab_size))
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
else:
return logits
...@@ -84,6 +84,9 @@ class BlenderbotSmallConfig(PretrainedConfig):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models)
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
Example::
...@@ -128,6 +131,7 @@ class BlenderbotSmallConfig(PretrainedConfig):
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
forced_eos_token_id=2,
**kwargs
):
super().__init__(
...@@ -136,6 +140,7 @@ class BlenderbotSmallConfig(PretrainedConfig):
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
......
...@@ -1310,16 +1310,6 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
"use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
}
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
@staticmethod
def _force_token_id_to_be_generated(scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
......
...@@ -1452,10 +1452,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
+ layer_past_key_values[2:],
)
return (past[0], reordered_past)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1:
vocab_range = tf.constant(range(self.config.vocab_size))
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
else:
return logits
...@@ -111,6 +111,9 @@ class FSMTConfig(PretrainedConfig):
search when at least ``num_beams`` sentences are finished per batch or not.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
Examples::
...@@ -155,6 +158,7 @@ class FSMTConfig(PretrainedConfig):
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
forced_eos_token_id=2,
**common_kwargs
):
if "hidden_size" in common_kwargs:
...@@ -166,6 +170,7 @@ class FSMTConfig(PretrainedConfig):
decoder_start_token_id=decoder_start_token_id,
is_encoder_decoder=is_encoder_decoder,
tie_word_embeddings=tie_word_embeddings,
forced_eos_token_id=forced_eos_token_id,
**common_kwargs,
)
self.langs = langs
......
...@@ -1210,23 +1210,6 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(logits, self.config.eos_token_id)
return logits
def _force_token_ids_generation(self, scores, token_ids) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
if isinstance(token_ids, int):
token_ids = [token_ids]
all_but_token_ids_mask = torch.tensor(
[x for x in range(self.config.tgt_vocab_size) if x not in token_ids],
dtype=torch.long,
device=next(self.parameters()).device,
)
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
scores[:, all_but_token_ids_mask] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
......
...@@ -84,6 +84,9 @@ class MarianConfig(PretrainedConfig):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models)
forced_eos_token_id (:obj:`int`, `optional`, defaults to 0):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
Examples::
...@@ -127,6 +130,7 @@ class MarianConfig(PretrainedConfig):
gradient_checkpointing=False,
pad_token_id=58100,
eos_token_id=0,
forced_eos_token_id=0,
**kwargs
):
super().__init__(
...@@ -134,6 +138,7 @@ class MarianConfig(PretrainedConfig):
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
......
...@@ -1325,15 +1325,8 @@ class MarianMTModel(MarianPreTrainedModel):
def adjust_logits_during_generation(self, logits, cur_len, max_length):
logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
@staticmethod
def _force_token_id_to_be_generated(scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
......
...@@ -1470,10 +1470,17 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
)
return (past[0], reordered_past)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
    """Never predict pad_token_id. Predict </s> when max_length is reached."""
    vocab_range = tf.constant(range(self.config.vocab_size))
    logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
    if cur_len == max_length - 1:
        logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
    return logits
def adjust_logits_during_generation(
    self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
):
    """Never predict pad_token_id. Predict </s> when max_length is reached."""
    vocab_range = tf.constant(range(self.config.vocab_size))
    logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
    if cur_len == 1 and forced_bos_token_id is not None:
        vocab_range = tf.constant(range(self.config.vocab_size))
        return tf.where(vocab_range != forced_bos_token_id, LARGE_NEGATIVE, logits)
    elif cur_len == max_length - 1 and forced_eos_token_id is not None:
        vocab_range = tf.constant(range(self.config.vocab_size))
        return tf.where(vocab_range != forced_eos_token_id, LARGE_NEGATIVE, logits)
    else:
        return logits
...@@ -84,6 +84,9 @@ class MBartConfig(PretrainedConfig):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models)
forced_eos_token_id (:obj:`int`, `optional`, defaults to 2):
The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to
:obj:`eos_token_id`.
Example::
...@@ -127,6 +130,7 @@ class MBartConfig(PretrainedConfig):
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
forced_eos_token_id=2,
**kwargs
):
super().__init__(
...@@ -134,6 +138,7 @@ class MBartConfig(PretrainedConfig):
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
......
...@@ -1344,16 +1344,6 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return shift_tokens_right(labels, self.config.pad_token_id)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
return logits
@staticmethod
def _force_token_id_to_be_generated(scores, token_id) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
......