Commit 7cba11fb authored by Patrick von Platen

better naming

parent ff648221
@@ -943,7 +943,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         return outputs

     @staticmethod
-    def prepare_inputs_for_generation_1(input_ids, past, decoder_input_ids, attention_mask):
+    def prepare_inputs_for_generation_bart(input_ids, past, decoder_input_ids, attention_mask):
         if past is None:  # first step
             encoder_outputs, decoder_cached_states = None, None
         else:
@@ -993,7 +993,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         return self.lm_head

     @torch.no_grad()
-    def generate_1(
+    def generate_bart(
         self,
         input_ids,
         attention_mask=None,
@@ -1113,7 +1113,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         self.model.decoder.generation_mode = True  # tells decoder not to use causal mask
         for step in range(max_length + 1):
             decoder_input_ids = prev_output_tokens.clone()
-            model_inputs = self.prepare_inputs_for_generation_1(
+            model_inputs = self.prepare_inputs_for_generation_bart(
                 input_ids, decoder_cache, decoder_input_ids, attention_mask,
             )
             outputs = self(**model_inputs)
...
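For readers skimming the rename: prepare_inputs_for_generation_bart encodes the usual incremental-decoding contract, where the first step has no cache (so the encoder must run) and every later step reuses the cached encoder outputs and feeds only the newest decoder token. A minimal sketch of that contract, assuming a (encoder_outputs, decoder_cached_states) cache layout and an illustrative return dict, not the exact branch API:

```python
import torch

def prepare_inputs_sketch(input_ids, past, decoder_input_ids, attention_mask):
    if past is None:  # first step: nothing cached yet, the encoder still has to run
        encoder_outputs, decoder_cached_states = None, None
    else:  # later steps: reuse the cache and keep only the latest decoder token
        encoder_outputs, decoder_cached_states = past
        decoder_input_ids = decoder_input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
        "encoder_outputs": encoder_outputs,
        "decoder_cached_states": decoder_cached_states,
    }

# First step: no cache, the full prompt would go through the encoder.
first = prepare_inputs_sketch(torch.tensor([[0, 5, 6]]), None, torch.tensor([[0]]), None)
# Later step: with a (pretend) cache, only the newest decoder token survives.
later = prepare_inputs_sketch(
    torch.tensor([[0, 5, 6]]), ("enc", "cache"), torch.tensor([[0, 7, 8]]), None
)
print(later["decoder_input_ids"])  # tensor([[8]])
```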
@@ -411,7 +411,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         else:
             raise EnvironmentError(
                 "Error no file named {} found in directory {} or `from_tf` set to False".format(
-                    [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index",],
+                    [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
                     pretrained_model_name_or_path,
                 )
             )
@@ -816,7 +816,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             effective_batch_size * num_beams, input_ids_len
         )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

-        # TODO (PVP): check eos_token_id
         # TODO (PVP): probably not the best way to check whether model is encoder decoder
         is_encoder_decoder = (
             hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
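The TODO kept above concerns the duck-typed encoder-decoder check. A standalone illustration of that hasattr pattern; DummySeq2Seq and _Inner are stand-ins invented for the demo, not library classes:

```python
# Stand-in objects mimicking a model that wraps .model.encoder / .model.decoder.
class _Inner:
    encoder = object()
    decoder = object()

class DummySeq2Seq:
    model = _Inner()

m = DummySeq2Seq()
is_encoder_decoder = (
    hasattr(m, "model") and hasattr(m.model, "decoder") and hasattr(m.model, "encoder")
)
print(is_encoder_decoder)  # True
```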
@@ -829,7 +828,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                # eos_token_id,  # Why eos_token_id here? bos_token_id makes more sense no?
+                # eos_token_id,  # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
                 bos_token_id,
                 dtype=torch.long,
                 device=next(self.parameters()).device,
...
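The bos_token_id hunk above is the step that seeds generation for encoder-decoder models: the original input_ids go to the encoder, while the decoder restarts from a single bos token per (batch example, beam) row. In isolation, with dummy sizes and a dummy token id:

```python
import torch

effective_batch_size, num_beams, bos_token_id = 2, 4, 0  # dummy values

# One decoder row per (batch example, beam) pair, each starting with <bos>.
decoder_start = torch.full(
    (effective_batch_size * num_beams, 1),
    bos_token_id,
    dtype=torch.long,
)
print(decoder_start.shape)  # torch.Size([8, 1])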
@@ -427,7 +427,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
         tokens = tok.encode(text, return_tensors="pt").to(torch_device)
         extra_len = 20
-        gen_tokens_bart = hf.generate_1(tokens, num_beams=4, max_length=extra_len,)  # repetition_penalty=10.,
+        gen_tokens_bart = hf.generate_bart(tokens, num_beams=3, max_length=extra_len,)  # repetition_penalty=10.,
         gen_tokens = hf.generate(
             tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
         )  # repetition_penalty=10.,
...
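A hypothetical end-to-end use of the renamed method, mirroring the updated test; the bart-large checkpoint name is an assumption from this era of the library, and generate_bart exists only on this development branch:

```python
from transformers import BartForConditionalGeneration, BartTokenizer

# Checkpoint name is assumed; generate_bart is already wrapped in
# @torch.no_grad() in the diff above, so no explicit no_grad block is needed.
tok = BartTokenizer.from_pretrained("bart-large")
hf = BartForConditionalGeneration.from_pretrained("bart-large")

tokens = tok.encode(
    "The Palestinian Authority officially became the 123rd member of the ICC.",
    return_tensors="pt",
)
gen_tokens_bart = hf.generate_bart(tokens, num_beams=3, max_length=20)
print(tok.decode(gen_tokens_bart[0], skip_special_tokens=True))
```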