Commit d039c679 authored by patrickvonplaten's avatar patrickvonplaten
Browse files

better naming for if statement

parent 7e0c5c73
...@@ -542,7 +542,11 @@ class PreTrainedModel(nn.Module): ...@@ -542,7 +542,11 @@ class PreTrainedModel(nn.Module):
def _do_output_past(self, outputs): def _do_output_past(self, outputs):
# TODO: might be better to write a self.do_output_past method for each individual class as is done for # TODO: might be better to write a self.do_output_past method for each individual class as is done for
# prepare_inputs_for_generation # prepare_inputs_for_generation
if hasattr(self.config, 'output_past') and self.config.output_past and len(outputs) > 1 and not hasattr(self, 'mem_len'): has_output_past = hasattr(self.config, 'output_past') and self.config.output_past
has_multiple_outputs = len(outputs) > 1
has_mem_len = hasattr(self, 'mem_len')
if has_output_past and has_multiple_outputs and not has_mem_len:
return True return True
# TODO: Add cases for (xlnet, transfo_xl) using mem_len # TODO: Add cases for (xlnet, transfo_xl) using mem_len
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment