"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e2c935f5615a3c15ee7439fa8a560edd5f13a457"
Commit 41b437ea authored by patrickvonplaten, committed by Patrick von Platen

add draft version of proposed changes for ROUGE score

parent a5751f75
@@ -846,7 +846,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
+                # eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
+                bos_token_id,
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
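This hunk swaps the seed token for encoder-decoder generation: the decoder's initial input_ids are now filled with bos_token_id instead of eos_token_id. A minimal standalone sketch of the tensor being built (the sizes and token id are illustrative, not taken from the repo):

import torch

# Illustrative values; in the library these come from generate()'s arguments.
effective_batch_size, num_beams = 2, 4
bos_token_id = 0

# One start token per beam: shape (batch * beams, 1), filled with BOS.
input_ids = torch.full(
    (effective_batch_size * num_beams, 1),
    bos_token_id,
    dtype=torch.long,
)
print(input_ids.shape)  # torch.Size([8, 1])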
@@ -1079,10 +1080,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token_logits = next_token_logits / temperature
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
-            if (
-                self.config.is_encoder_decoder and do_sample is False
-            ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
-                scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
+            # if (
+            #     self.config.is_encoder_decoder and do_sample is False
+            # ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+            #     scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
             # set eos token prob to zero if min_length is not reached
             if eos_token_ids is not None and cur_len < min_length:
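This hunk disables the prepare_scores_for_generation hook during beam search, leaving only the min-length constraint visible in the context below it. A hedged sketch of that surviving logic, with assumed toy shapes and token ids (not a verbatim excerpt):

import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size = 2, 4, 50    # assumed toy sizes
eos_token_ids, cur_len, min_length = [2], 1, 5  # assumed values

next_token_logits = torch.randn(batch_size * num_beams, vocab_size)
scores = F.log_softmax(next_token_logits, dim=-1)  # (batch * beams, vocab)

# Forbid EOS while the sequence is still shorter than min_length,
# so no beam can terminate early.
if eos_token_ids is not None and cur_len < min_length:
    for eos in eos_token_ids:
        scores[:, eos] = -float("inf")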
@@ -1271,9 +1272,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert (len(hypo) == max_length for hypo in best)
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
-        if self.config.is_encoder_decoder:
-            # do not return first <EOS> token
-            return decoded[:, 1:]
+        # if self.config.is_encoder_decoder:
+        #     # do not return first <EOS> token
+        #     return decoded[:, 1:]
         return decoded

     # force one of token_ids to be generated by setting prob of all other tokens to 0.
...
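The slice removal follows from the first hunk: generation now starts from BOS rather than EOS, so there is no spurious leading EOS token to strip from the decoded output. Illustration with made-up token ids:

import torch

decoded = torch.tensor([[0, 31, 7, 2]])  # hypothetical [BOS, ..., EOS] beam output
# Before this commit (encoder-decoder case): return decoded[:, 1:] dropped the seed token.
# After: the sequence is returned unchanged, seed token included.
print(decoded)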
@@ -214,6 +214,9 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
             output_past=output_past,
+            eos_token_id=2,
+            pad_token_id=1,
+            bos_token_id=0,
         )
         return config, input_ids, batch_size
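The three new kwargs pin BART's conventional special-token ids (bos=0, pad=1, eos=2) in the tiny test config, so the generation changes above do not depend on config defaults. A sketch of the resulting config; vocab_size is an assumption and the other kwargs are taken from the surrounding diff:

from transformers import BartConfig

config = BartConfig(
    vocab_size=99,                # assumed toy vocabulary for the head tests
    decoder_ffn_dim=32,
    max_position_embeddings=48,
    bos_token_id=0,               # BART's standard special-token assignment
    pad_token_id=1,
    eos_token_id=2,
)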
@@ -468,7 +471,8 @@ class BartModelIntegrationTest(unittest.TestCase):
             attention_mask=dct["attention_mask"].to(torch_device),
             num_beams=4,
             length_penalty=2.0,
-            max_length=max_length + 2,
+            # max_length=max_length + 2,
+            max_length=max_length + 1,
             min_length=min_length + 1,
             no_repeat_ngram_size=3,
             do_sample=False,
...
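The integration test's max_length headroom shrinks from +2 to +1 to match the new output convention: since the leading token is no longer stripped, one fewer generation step yields a summary of the same length. A hedged reconstruction of the call site; model, dct, torch_device, and the length variables are assumed from the surrounding test, not quoted verbatim:

summaries = model.generate(
    dct["input_ids"].to(torch_device),
    attention_mask=dct["attention_mask"].to(torch_device),
    num_beams=4,
    length_penalty=2.0,
    max_length=max_length + 1,  # was max_length + 2 before this commit
    min_length=min_length + 1,
    no_repeat_ngram_size=3,
    do_sample=False,
)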