Commit 7351a8db authored by Patrick von Platen

re-add scoring filtering

parent 9b8ee8ce
@@ -1084,10 +1084,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token_logits = next_token_logits / temperature
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
-            # if (
-            #     self.config.is_encoder_decoder and do_sample is False
-            # ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
-            #     scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
+            if self.config.is_encoder_decoder and do_sample is False:
+                # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+                scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
             # set eos token prob to zero if min_length is not reached
             if eos_token_ids is not None and cur_len < min_length:
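For context, prepare_scores_for_generation is a per-model hook that lets an encoder-decoder model adjust the log-softmax scores at particular decoding steps before tokens are selected; re-enabling the call above is the "scoring filtering" named in the commit title. Below is a minimal, hypothetical sketch of the kind of rule such a hook can apply; the function name suffix, token ids, and the two rules are assumptions for illustration, not the library's actual implementation.

import torch

def prepare_scores_for_generation_sketch(scores, cur_len, max_length, bos_token_id=0, eos_token_id=2):
    # scores: (batch_size * num_beams, vocab_size) log-probabilities from F.log_softmax.
    if cur_len == 1:
        # Illustrative rule: at the first decoder step, allow only the BOS token
        # by pushing every other token's log-prob to -inf.
        forced = torch.full_like(scores, float("-inf"))
        forced[:, bos_token_id] = scores[:, bos_token_id]
        return forced
    if cur_len == max_length - 1:
        # Illustrative rule: at the last step, allow only EOS so every sequence terminates.
        forced = torch.full_like(scores, float("-inf"))
        forced[:, eos_token_id] = scores[:, eos_token_id]
        return forced
    return scores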
@@ -1279,10 +1278,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert (len(hypo) == max_length for hypo in best)
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
-        # if self.config.is_encoder_decoder:
-        # do not return first <EOS> token
-        # return decoded[:, 1:]
-        return decoded
+        if self.config.is_encoder_decoder:
+            # do not return first <EOS> token
+            return decoded[:, 1:]
+        # return decoded
     # force one of token_ids to be generated by setting prob of all other tokens to 0.
     def _force_token_ids_generation(self, scores, token_ids):
...
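The _force_token_ids_generation helper in the context lines above implements its comment literally: "setting prob of all other tokens to 0" means setting their log-probabilities to -inf. A standalone sketch of that masking, assuming scores is the usual (batch_size * num_beams, vocab_size) log-probability tensor (the body is illustrative, not the exact implementation):

import torch

def force_token_ids_generation_sketch(scores, token_ids):
    # Allow only `token_ids` at this step: in log-space, "probability 0"
    # corresponds to a log-prob of -inf for every other vocabulary entry.
    if isinstance(token_ids, int):
        token_ids = [token_ids]
    banned = torch.ones_like(scores, dtype=torch.bool)
    banned[:, token_ids] = False  # keep the allowed ids unmasked
    return scores.masked_fill(banned, float("-inf"))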
@@ -471,8 +471,7 @@ class BartModelIntegrationTest(unittest.TestCase):
             attention_mask=dct["attention_mask"].to(torch_device),
             num_beams=4,
             length_penalty=2.0,
-            # max_length=max_length + 2,
-            max_length=max_length + 1,
+            max_length=max_length + 2,
             min_length=min_length + 1,
             no_repeat_ngram_size=3,
             do_sample=False,
...
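Put together, the beam-search call this test exercises looks roughly as follows. The names model, input_ids, and attention_mask are placeholders for the Bart checkpoint and tokenized article the real test builds; the keyword arguments mirror the hunk above.

summaries = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    num_beams=4,
    length_penalty=2.0,
    max_length=max_length + 2,
    min_length=min_length + 1,
    no_repeat_ngram_size=3,
    do_sample=False,
)
# With the encoder-decoder branch re-enabled above, the returned sequences have
# already had the first <EOS> token stripped (decoded[:, 1:] in the second hunk).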