"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "82601f4c1a5c3edb680593bdd9b54abd5846cfa7"
Unverified commit 7d65697d authored by Joao Gante, committed by GitHub

Generate: remove Marian hack (#25294)

Remove Marian hack
parent 14510938
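Context for the removal (editor's note, not part of the commit): the hack existed so that Marian could never emit `pad_token_id`, and the removed comments stress it must hold "both before and after" the `log_softmax`. A single `-inf` mask applied before `log_softmax` already guarantees both, because `log_softmax` maps `-inf` to `-inf`, which is presumably why the per-model hook can be folded into the regular logits-processor pipeline. A minimal demonstration of that invariant:

import torch
import torch.nn.functional as F

# Toy next-token logits: batch of 1, vocab of 5; suppose pad_token_id == 3.
logits = torch.tensor([[2.0, 0.5, -1.0, 4.0, 1.0]])
pad_token_id = 3

logits[:, pad_token_id] = float("-inf")  # mask once, before log_softmax
scores = F.log_softmax(logits, dim=-1)

print(scores[0, pad_token_id])  # tensor(-inf): still impossible afterwards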
@@ -474,27 +474,6 @@ class TFGenerationMixin:
                 "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
             )

-    def adjust_logits_during_generation(
-        self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
-    ):
-        """
-        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
-        """
-        vocab_size = getattr(self.config, "vocab_size", None)
-        if vocab_size is None and self.config.is_encoder_decoder:
-            decoder_config = getattr(self.config, "decoder", None)
-            if decoder_config is not None:
-                vocab_size = getattr(self.config.decoder, "vocab_size", None)
-
-        if cur_len == 1 and forced_bos_token_id is not None:
-            vocab_range = tf.constant(range(vocab_size))
-            return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
-        elif cur_len == max_length - 1 and forced_eos_token_id is not None:
-            vocab_range = tf.constant(range(vocab_size))
-            return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
-        else:
-            return logits
-
     def compute_transition_scores(
         self,
         sequences: tf.Tensor,
...
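The TF hook above hard-coded forced-BOS/EOS masking that the generic TF logits processors (`TFForcedBOSTokenLogitsProcessor`, `TFForcedEOSTokenLogitsProcessor`) also implement. A standalone sketch of the masking pattern it used; the helper name here is illustrative, not from the codebase:

import tensorflow as tf

LARGE_NEGATIVE = -1e8  # same magnitude the removed hook used

def force_token(logits: tf.Tensor, token_id: int) -> tf.Tensor:
    # Push every logit except `token_id` to a large negative value, so the
    # forced token is the only one that can be chosen at this step.
    vocab_range = tf.range(tf.shape(logits)[-1])
    return tf.where(vocab_range != token_id, LARGE_NEGATIVE, logits)

logits = tf.random.normal((2, 8))  # (batch_size, vocab_size), toy values
print(force_token(logits, token_id=5))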
@@ -578,12 +578,6 @@ class GenerationMixin:
         inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
         return inputs, input_name, model_kwargs

-    def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        """
-        Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
-        """
-        return logits
-
     def _maybe_initialize_input_ids_for_generation(
         self,
         inputs: Optional[torch.Tensor] = None,
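With the base-class hook gone, per-step logit edits go through the standard `LogitsProcessor` interface rather than a model override. A minimal sketch of how the same pad suppression could be expressed that way (the class is illustrative, not something this commit adds):

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class SuppressPadLogitsProcessor(LogitsProcessor):
    # Illustrative stand-in for the removed hook: make one token id
    # unreachable at every decoding step.
    def __init__(self, pad_token_id: int):
        self.pad_token_id = pad_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.pad_token_id] = float("-inf")
        return scores

# Usage (pad_id is whatever the model's config defines):
# model.generate(**inputs, logits_processor=LogitsProcessorList([SuppressPadLogitsProcessor(pad_id)]))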
@@ -3060,9 +3054,6 @@ class GenerationMixin:
                 continue  # don't waste resources running the code we don't need

             next_token_logits = outputs.logits[:, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
@@ -3388,9 +3379,6 @@ class GenerationMixin:
             next_token_logits = outputs.logits[:, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
@@ -3751,9 +3739,6 @@ class GenerationMixin:
             # select outputs of beams of current group only
             next_token_logits = outputs.logits[batch_group_indices, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * group_size, vocab_size)
@@ -4110,9 +4095,6 @@ class GenerationMixin:
                 continue  # don't waste resources running the code we don't need

             next_token_logits = outputs.logits[:, -1, :]
-            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
-            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
...
@@ -1524,10 +1524,6 @@ class MarianMTModel(MarianPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

-    def adjust_logits_during_generation(self, logits, cur_len):
-        logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
-        return logits
-
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
...
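For MarianMTModel specifically, the same "never predict pad" behavior is available through `generate`'s standard knobs: `bad_words_ids` routes through `NoBadWordsLogitsProcessor`. Whether Marian checkpoints already carry this in their shipped config is an assumption worth verifying. A usage sketch against a real public checkpoint:

from transformers import MarianMTModel, MarianTokenizer

name = "Helsinki-NLP/opus-mt-en-de"  # real public checkpoint, used for illustration
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)

batch = tokenizer("Hello, world!", return_tensors="pt")
# Explicitly suppress pad at every step via the standard processor route.
out = model.generate(**batch, num_beams=4, bad_words_ids=[[model.config.pad_token_id]])
print(tokenizer.batch_decode(out, skip_special_tokens=True))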
@@ -1443,18 +1443,3 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

-    def adjust_logits_during_generation(
-        self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
-    ):
-        """Never predict pad_token_id. Predict </s> when max_length is reached."""
-        vocab_range = tf.constant(range(self.config.vocab_size))
-        logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
-        if cur_len == 1 and forced_bos_token_id is not None:
-            vocab_range = tf.constant(range(self.config.vocab_size))
-            return tf.where(vocab_range != forced_bos_token_id, LARGE_NEGATIVE, logits)
-        elif cur_len == max_length - 1 and forced_eos_token_id is not None:
-            vocab_range = tf.constant(range(self.config.vocab_size))
-            return tf.where(vocab_range != forced_eos_token_id, LARGE_NEGATIVE, logits)
-        else:
-            return logits
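The TF side mirrors this: once the override is gone, TFMarianMTModel relies on the same generate-level machinery (TF has matching processors such as `TFNoBadWordsLogitsProcessor`). A minimal TF usage sketch under the same assumptions as above:

from transformers import TFMarianMTModel, MarianTokenizer

name = "Helsinki-NLP/opus-mt-en-de"  # same illustrative checkpoint as above
tokenizer = MarianTokenizer.from_pretrained(name)
model = TFMarianMTModel.from_pretrained(name)

batch = tokenizer("Hello, world!", return_tensors="tf")
out = model.generate(**batch, num_beams=4, bad_words_ids=[[model.config.pad_token_id]])
print(tokenizer.batch_decode(out, skip_special_tokens=True))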