Commit 41b437ea authored by patrickvonplaten, committed by Patrick von Platen

add draft version of proposed changes for ROUGE score

parent a5751f75
@@ -846,7 +846,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
+                # eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
+                bos_token_id,
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
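The change above swaps the forced decoder start token from `eos_token_id` to `bos_token_id` when seeding encoder-decoder generation. A minimal sketch of the tensor this builds, with hypothetical batch/beam sizes and token ids:

```python
import torch

# Hypothetical values for illustration; the real ids come from the model config.
effective_batch_size, num_beams = 2, 4
bos_token_id = 0

# Every beam of every batch element starts decoding from the same BOS token:
# a (batch * beams, 1) tensor filled with bos_token_id.
input_ids = torch.full(
    (effective_batch_size * num_beams, 1),
    bos_token_id,
    dtype=torch.long,
)
print(input_ids.shape)  # torch.Size([8, 1])
```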
@@ -1079,10 +1080,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token_logits = next_token_logits / temperature
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
-            if (
-                self.config.is_encoder_decoder and do_sample is False
-            ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
-                scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
+            # if (
+            #     self.config.is_encoder_decoder and do_sample is False
+            # ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+            #     scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
             # set eos token prob to zero if min_length is not reached
             if eos_token_ids is not None and cur_len < min_length:
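The trailing context lines show the step that follows the commented-out block: zeroing out the EOS probability until `min_length` is reached. A self-contained sketch of that masking, with hypothetical shapes and token ids:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes and ids for illustration.
batch_times_beams, vocab_size = 8, 50265
eos_token_ids, min_length, cur_len = [2], 10, 3

next_token_logits = torch.randn(batch_times_beams, vocab_size)
scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

# Before min_length is reached, pushing the EOS log-prob to -inf
# guarantees beam search cannot end any sequence early.
if eos_token_ids is not None and cur_len < min_length:
    for eos in eos_token_ids:
        scores[:, eos] = -float("inf")
```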
@@ -1271,9 +1272,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert (len(hypo) == max_length for hypo in best)
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
-        if self.config.is_encoder_decoder:
+        # if self.config.is_encoder_decoder:
         # do not return first <EOS> token
-            return decoded[:, 1:]
+        # return decoded[:, 1:]
         return decoded

 # force one of token_ids to be generated by setting prob of all other tokens to 0.
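With the `is_encoder_decoder` branch commented out, `generate` now returns the full `decoded` tensor rather than `decoded[:, 1:]`, which previously dropped the forced leading token. A quick illustration of what that slice removes, using made-up token ids:

```python
import torch

# Hypothetical decoded beam-search output; the first column is the forced start token.
decoded = torch.tensor([
    [2, 464, 3290, 4],  # <eos> The dog .
    [2, 464, 3797, 4],  # <eos> The cat .
])

# decoded[:, 1:] strips that first token from every sequence;
# returning `decoded` unchanged now keeps it.
print(decoded[:, 1:])  # tensor([[464, 3290, 4], [464, 3797, 4]])
```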
@@ -214,6 +214,9 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
             output_past=output_past,
+            eos_token_id=2,
+            pad_token_id=1,
+            bos_token_id=0,
         )
         return config, input_ids, batch_size
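For context, a hedged sketch of the tiny test config these three token ids slot into. The ids (eos=2, pad=1, bos=0) match BART's standard vocabulary layout; values other than those visible in the hunk are illustrative, not copied from the test:

```python
from transformers import BartConfig

# A deliberately tiny config in the spirit of the test above.
config = BartConfig(
    vocab_size=99,
    d_model=24,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=48,
    eos_token_id=2,
    pad_token_id=1,
    bos_token_id=0,
)
```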
@@ -468,7 +471,8 @@ class BartModelIntegrationTest(unittest.TestCase):
             attention_mask=dct["attention_mask"].to(torch_device),
             num_beams=4,
             length_penalty=2.0,
-            max_length=max_length + 2,
+            # max_length=max_length + 2,
+            max_length=max_length + 1,
             min_length=min_length + 1,
             no_repeat_ngram_size=3,
             do_sample=False,
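A hedged sketch of the summarization call these arguments belong to, assuming a BART CNN/DailyMail checkpoint; the checkpoint name and the concrete length values are assumptions for illustration, not copied from the test:

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

# Assumed checkpoint; the test pins its own model and length budgets.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(torch_device)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

dct = tokenizer(["Some long news article to summarize ..."], return_tensors="pt", truncation=True)
summaries = model.generate(
    input_ids=dct["input_ids"].to(torch_device),
    attention_mask=dct["attention_mask"].to(torch_device),
    num_beams=4,
    length_penalty=2.0,
    max_length=140 + 1,  # max_length + 1, mirroring the change above
    min_length=55 + 1,   # min_length + 1
    no_repeat_ngram_size=3,
    do_sample=False,
)
print(tokenizer.decode(summaries[0], skip_special_tokens=True))
```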