"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4c7f564f9a8c69430c55dce1b1f93c9e65d5944d"
Unverified commit 8a377c3d, authored by Sam Shleifer, committed by GitHub

[fix] Move _adjust_logits above postprocess to fix Marian.generate (#5126)

parent 3d3e605a
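Summary of the change, as the diff below shows: PreTrainedModel.prepare_logits_for_generation is renamed to adjust_logits_during_generation, and the call is moved earlier in the generation loop so that it runs on the raw next_token_logits before F.log_softmax and before the score postprocessing, rather than on the already-normalized beam-search scores. This matters for MarianMTModel, whose override masks the pad token with -inf and forces the EOS token on the final step: applied before log_softmax, the mask renormalizes probability mass over the remaining tokens; applied afterwards, the scores no longer form a valid log-distribution.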
@@ -993,7 +993,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
 
-    def prepare_logits_for_generation(self, logits, cur_len, max_length):
+    def adjust_logits_during_generation(self, logits, cur_len, max_length):
         if cur_len == 1:
             self._force_token_ids_generation(logits, self.config.bos_token_id)
         if cur_len == max_length - 1 and self.config.eos_token_id is not None:

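For context, _force_token_ids_generation itself is not shown in this diff. A minimal sketch of the behavior its name and call sites imply (a hypothetical reconstruction, not the library's code):

import torch

def force_token_id(logits: torch.Tensor, token_id: int) -> None:
    # In-place: mask every vocabulary entry except `token_id` with -inf,
    # so softmax/argmax over the adjusted logits can only pick `token_id`.
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask[:, token_id] = False
    logits[mask] = float("-inf")
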
@@ -46,7 +46,7 @@ class MarianMTModel(BartForConditionalGeneration):
     """
 
-    def prepare_logits_for_generation(self, logits, cur_len, max_length):
+    def adjust_logits_during_generation(self, logits, cur_len, max_length):
         logits[:, self.config.pad_token_id] = float("-inf")
         if cur_len == max_length - 1 and self.config.eos_token_id is not None:
             self._force_token_ids_generation(logits, self.config.eos_token_id)

@@ -792,7 +792,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
-    def prepare_logits_for_generation(self, logits, **kwargs):
+    def adjust_logits_during_generation(self, logits, **kwargs):
         return logits
 
     def _use_cache(self, outputs, use_cache):

@@ -1396,6 +1396,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             # if model has past, then set the past variable to speed up decoding
             if self._use_cache(outputs, use_cache):
                 past = outputs[1]
+            if self.config.is_encoder_decoder and do_sample is False:
+                # TODO (PVP) still a bit hacky here - there might be a better solution
+                next_token_logits = self.adjust_logits_during_generation(
+                    next_token_logits, cur_len=cur_len, max_length=max_length
+                )
 
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

@@ -1413,10 +1418,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 num_beams=num_beams,
             )
 
-            if self.config.is_encoder_decoder and do_sample is False:
-                # TODO (PVP) still a bit hacky here - there might be a better solution
-                scores = self.prepare_logits_for_generation(scores, cur_len=cur_len, max_length=max_length)
-
             assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
                 scores.shape, (batch_size * num_beams, vocab_size)
             )
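Why the ordering matters: a minimal standalone sketch, not part of the commit, assuming a toy 3-token vocabulary. Masking a token with -inf before log_softmax renormalizes the remaining probabilities; masking after log_softmax leaves an unnormalized distribution, which is what broke Marian.generate:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5]])  # pretend index 0 is the pad token
pad_token_id = 0

# New behavior: adjust the raw logits, then normalize.
adjusted = logits.clone()
adjusted[:, pad_token_id] = float("-inf")
scores_new = F.log_softmax(adjusted, dim=-1)
print(scores_new.exp().sum())  # ~1.0: a valid distribution over non-pad tokens

# Old behavior: normalize first, then mask the already-normalized scores.
scores_old = F.log_softmax(logits, dim=-1)
scores_old[:, pad_token_id] = float("-inf")
print(scores_old.exp().sum())  # < 1.0: probabilities no longer sum to one

Moving the call above log_softmax (and above the score postprocessing) means beam search always operates on a valid log-distribution.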