"...resnet50_tensorflow.git" did not exist on "72f5834cc19bcfd3aef795dd926575fd9e0db802"
Commit fc84bd52 authored by patrickvonplaten

adapt style to predefined style layout

parent deff792b
@@ -492,7 +492,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
-        if 'past' in kwargs and kwargs['past']:
+        if "past" in kwargs and kwargs["past"]:
             input_ids = input_ids[:, -1].unsqueeze(-1)

         inputs = {"input_ids": input_ids}
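For context, the branch touched here is what lets generation reuse the model's cached key/value states: once `past` is available, only the newly produced token has to be fed back in. The loop below is a minimal greedy-decoding sketch of that pattern, not the library's `generate()` implementation; it assumes a GPT-2 checkpoint and the transformers/torch API of this period (where GPT-2's forward still takes a `past` argument and returns its cache as the second output).

# Sketch only: why input_ids is trimmed to the last token once `past` exists.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer.encode("The cache makes decoding", return_tensors="pt")
past = None
with torch.no_grad():
    for _ in range(20):
        if past is not None:
            # the cache already covers earlier positions, so only the
            # newest token needs to be fed back (mirrors the hunk above)
            model_inputs = {"input_ids": input_ids[:, -1].unsqueeze(-1), "past": past}
        else:
            model_inputs = {"input_ids": input_ids}
        outputs = model(**model_inputs)
        next_token = torch.argmax(outputs[0][:, -1, :], dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        past = outputs[1]  # key/value cache ("presents") returned by GPT-2

print(tokenizer.decode(input_ids[0]))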
@@ -561,7 +561,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
-        if 'past' in kwargs and kwargs['past']:
+        if "past" in kwargs and kwargs["past"]:
             input_ids = input_ids[:, -1].unsqueeze(-1)

         inputs = {"input_ids": input_ids}
@@ -935,7 +935,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         inputs = {"input_ids": input_ids}

         # if past is defined in model kwargs then use it for faster decoding
-        if 'past' in model_kwargs and model_kwargs['past']:
-            inputs['mems'] = model_kwargs['past']
+        if "past" in model_kwargs and model_kwargs["past"]:
+            inputs["mems"] = model_kwargs["past"]

         return inputs
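Transfo-XL (and XLNet further down) expose their cache through a `mems` argument rather than `past`, which is why the generic `past` kwarg is remapped here. The snippet below is an illustrative sketch of that remapping during decoding; the checkpoint name and the assumption that the model returns its updated memories as the second output follow the transformers API of this period.

# Sketch only: the generic `past` cache is fed back to Transfo-XL as `mems`.
import torch
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
model.eval()

generated = tokenizer.encode("the memory helps with long contexts", return_tensors="pt")
next_input = generated      # first step sees the whole prompt
past = None                 # what a decoding loop would carry between steps
with torch.no_grad():
    for _ in range(5):
        inputs = {"input_ids": next_input}
        if past:
            inputs["mems"] = past  # same remapping as in the hunk above
        outputs = model(**inputs)
        next_token = torch.argmax(outputs[0][:, -1, :], dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        past = outputs[1]          # updated memories come back as the second output
        next_input = next_token    # later steps only need the new token plus mems

print(tokenizer.decode(generated[0]))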
@@ -540,8 +540,8 @@ class PreTrainedModel(nn.Module):
         return {"input_ids": input_ids}

     def _do_output_past(self, outputs):
-        has_output_past = hasattr(self.config, 'output_past') and self.config.output_past
-        has_mem_len = hasattr(self.config, 'mem_len') and self.config.mem_len
+        has_output_past = hasattr(self.config, "output_past") and self.config.output_past
+        has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
         if has_output_past and not has_mem_len and len(outputs) > 1:
             return True
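_do_output_past decides whether the second element of a model's output tuple can be reused as `past`: the branch shown fires only when the config advertises `output_past` and the model is not a memory-based one (no `mem_len`). Below is a toy reproduction of just the lines visible in the hunk; `DummyConfig` and the sample values are made up for illustration and are not part of the library.

# Sketch only: the visible part of the _do_output_past check, applied to toy configs.
class DummyConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

def shown_branch(config, outputs):
    # mirrors only the lines visible in the hunk above
    has_output_past = hasattr(config, "output_past") and config.output_past
    has_mem_len = hasattr(config, "mem_len") and config.mem_len
    return bool(has_output_past and not has_mem_len and len(outputs) > 1)

gpt2_like = DummyConfig(output_past=True)           # caches via `past`
transfo_xl_like = DummyConfig(mem_len=1600)         # caches via `mems` instead
fake_outputs = ("logits", "cache")                  # stand-in for a model output tuple

print(shown_branch(gpt2_like, fake_outputs))        # True  -> reuse outputs[1] as past
print(shown_branch(transfo_xl_like, fake_outputs))  # False -> handled by the mem_len branch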
@@ -1031,8 +1031,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}

         # if past is defined in model kwargs then use it for faster decoding
-        if 'past' in model_kwargs and model_kwargs['past']:
-            inputs['mems'] = model_kwargs['past']
+        if "past" in model_kwargs and model_kwargs["past"]:
+            inputs["mems"] = model_kwargs["past"]

         return inputs