Commit 365ccd0a authored by patrickvonplaten's avatar patrickvonplaten
Browse files

make if statements cleaner for prepare_inputs_for_generation

parent d039c679
...@@ -491,8 +491,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -491,8 +491,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
return self.lm_head return self.lm_head
def prepare_inputs_for_generation(self, input_ids, **kwargs): def prepare_inputs_for_generation(self, input_ids, **kwargs):
# inputs_ids should only be composed of last token if past is in kwargs and defined # only last token for inputs_ids if past is defined in kwargs
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids if 'past' in kwargs and kwargs['past']:
input_ids = input_ids[:, -1].unsqueeze(-1)
inputs = {"input_ids": input_ids} inputs = {"input_ids": input_ids}
inputs.update(kwargs) inputs.update(kwargs)
......
...@@ -560,8 +560,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -560,8 +560,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
return self.lm_head return self.lm_head
def prepare_inputs_for_generation(self, input_ids, **kwargs): def prepare_inputs_for_generation(self, input_ids, **kwargs):
# inputs_ids should only be composed of last token if past is in kwargs and defined # only last token for inputs_ids if past is defined in kwargs
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids if 'past' in kwargs and kwargs['past']:
input_ids = input_ids[:, -1].unsqueeze(-1)
inputs = {"input_ids": input_ids} inputs = {"input_ids": input_ids}
inputs.update(kwargs) inputs.update(kwargs)
......
...@@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module): ...@@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module):
return {"input_ids": input_ids} return {"input_ids": input_ids}
def _do_output_past(self, outputs): def _do_output_past(self, outputs):
# TODO: might be better to write a self.do_output_past method for each individual class as is done for # TODO: might be better to write a self.do_output_past method for each
# prepare_inputs_for_generation # individual class as is done for prepare_inputs_for_generation
has_output_past = hasattr(self.config, 'output_past') and self.config.output_past has_output_past = hasattr(self.config, 'output_past') and self.config.output_past
has_multiple_outputs = len(outputs) > 1 has_multiple_outputs = len(outputs) > 1
has_mem_len = hasattr(self, 'mem_len') has_mem_len = hasattr(self, 'mem_len')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment