Commit d880a5fb authored by patrickvonplaten, committed by Patrick von Platen

finalized PR

parent 2acfe639
@@ -798,7 +798,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

         # create attention mask if necessary
-        # TODO (PVP): this should later be handled by the forward fn() in each model
+        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
         if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
             attention_mask = input_ids.ne(pad_token_id).long()
         elif attention_mask is None:
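For readers skimming the diff, a minimal standalone sketch of what this fallback does (the token ids and `pad_token_id` below are made up): wherever the input contains the pad token, the derived attention mask is 0, otherwise 1.

```python
import torch

# Hypothetical batch of two sequences, right-padded with pad_token_id = 0.
pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 6, 8, 2]])

# Same expression as in generate(): 1 for real tokens, 0 for padding positions.
attention_mask = input_ids.ne(pad_token_id).long()
print(attention_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])
```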
@@ -989,10 +989,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if unfinished_sents.max() == 0:
                 break

-            # extend attention_mask for new generated input
+            # extend attention_mask for new generated input if only decoder
             if self.config.is_encoder_decoder is False:
                 attention_mask = torch.cat(
-                    [attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                 )

             cur_len = cur_len + 1
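The substantive fix in this hunk is the shape of the block appended to the attention mask at each generation step. A small sketch with made-up sizes shows why `(batch_size, 1)` is the correct shape: the mask must grow by one column per newly generated token, whereas the old `(1, cur_len)` shape either errors out or appends a whole row.

```python
import torch

batch_size, cur_len = 2, 5
attention_mask = torch.ones((batch_size, cur_len), dtype=torch.long)

# Old shape: (1, cur_len). For batch_size > 1, torch.cat along dim=-1 fails with a
# size mismatch on dim 0; for batch_size == 1 it silently appends cur_len ones
# instead of a single one:
# attention_mask.new_ones((1, attention_mask.shape[-1]))

# Fixed shape: (batch_size, 1) -- one extra visible position per sequence.
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
print(attention_mask.shape)  # torch.Size([2, 6])
```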
@@ -1078,7 +1078,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token_logits = next_token_logits / temperature

             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
-            if self.config.is_encoder_decoder:  # TODO(PVP) to be refactored later - do we need this boolean flag here?
+            if (
+                self.config.is_encoder_decoder
+            ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search?
                 scores = self.prepare_scores_for_generation(scores, cur_len, max_length)

             # set eos token prob to zero if min_length is not reached
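For context, a tiny standalone sketch (with invented logits and temperature) of the scoring step this hunk touches: `log_softmax` turns next-token logits into log-probabilities that beam search can accumulate, after which `prepare_scores_for_generation` gives a model-specific hook to adjust them.

```python
import torch
import torch.nn.functional as F

# Made-up logits for batch_size * num_beams = 2 hypotheses over a 4-token vocabulary.
next_token_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0],
                                  [0.1, 0.2, 3.0, 0.0]])
temperature = 0.7

next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

# Each row is a proper distribution, so beam search can sum log-probs per hypothesis.
print(scores.exp().sum(dim=-1))  # tensor([1.0000, 1.0000])
```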
@@ -1205,10 +1207,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if past:
                 past = self._reorder_cache(past, beam_idx)

-            # extend attention_mask for new generated input
+            # extend attention_mask for new generated input if only decoder
             if self.config.is_encoder_decoder is False:
                 attention_mask = torch.cat(
-                    [attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                 )

             # update current length
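`_reorder_cache` itself is model-specific and not shown in this diff; as a rough illustration only, reordering a cached key/value tensor usually amounts to an `index_select` along the batch-times-beams dimension with `beam_idx`, so the cache stays aligned with the surviving hypotheses. The tensor shape below is hypothetical.

```python
import torch

# Hypothetical cached key tensor: (batch_size * num_beams, num_heads, seq_len, head_dim).
past_key = torch.randn(4, 2, 5, 8)

# beam_idx says, for each of the 4 beam slots, which previous hypothesis it continues.
beam_idx = torch.tensor([1, 1, 2, 3])

# Rough sketch of what a model-specific _reorder_cache typically does per cached tensor.
reordered_key = past_key.index_select(0, beam_idx)
print(reordered_key.shape)  # torch.Size([4, 2, 5, 8])
```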
@@ -1270,7 +1272,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

         if self.config.is_encoder_decoder:
-            # do not return first <BOS> token
+            # do not return first <EOS> token
             return decoded[:, 1:]
         return decoded
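A quick illustration with dummy ids of the slicing above: for encoder-decoder models the generated sequences carry an extra leading special token (the decoder start token, which for BART is the `</s>`/`<EOS>` token, hence the comment fix), and it is dropped before returning.

```python
import torch

# Dummy generated ids; the first column is the decoder start token (id 2 in this sketch).
decoded = torch.tensor([[2, 17, 23, 41, 1],
                        [2, 99, 12,  1, 1]])

# Encoder-decoder models strip that leading special token before returning.
print(decoded[:, 1:])
# tensor([[17, 23, 41,  1],
#         [99, 12,  1,  1]])
```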
@@ -453,9 +453,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         EXPECTED_SUMMARY_SUBWAY = "Liana Barrientos has been married 10 times, sometimes within two weeks of each other. Prosecutors say the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx. She was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the subway."

         dct = tok.batch_encode_plus(
-            # [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
-            [IRAN_ARTICLE, ARTICLE_SUBWAY],
-            # [FRANCE_ARTICLE, SHORTER_ARTICLE],
+            [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY],
             max_length=1024,
             pad_to_max_length=True,
             return_tensors="pt",
@@ -482,9 +480,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         ]

         self.assertListEqual(
-            # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
-            [EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
-            # [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER],
+            [EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
             decoded,
         )
         # TODO(SS): run fairseq again with num_beams=2, min_len=20.
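The surrounding test flow, for orientation, is roughly as sketched below: batch-encode the articles, generate summaries with beam search, decode them, and compare against the expected strings. The checkpoint name, generation arguments, and decode call here are illustrative assumptions rather than the exact test code; only the `batch_encode_plus(..., max_length=1024, pad_to_max_length=True, return_tensors="pt")` call mirrors the diff.

```python
from transformers import BartForConditionalGeneration, BartTokenizer

# Checkpoint name and generation hyperparameters are placeholders, not the test's values.
tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

articles = ["First article text ...", "Second article text ..."]
dct = tok.batch_encode_plus(
    articles, max_length=1024, pad_to_max_length=True, return_tensors="pt",
)

summaries = model.generate(
    input_ids=dct["input_ids"],
    attention_mask=dct["attention_mask"],
    num_beams=4,        # assumed value
    max_length=140,     # assumed value
    early_stopping=True,
)
decoded = [
    tok.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    for ids in summaries
]
```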