"web/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "e651be551c02073a8d5b639f1f8c1bf11d78cd82"
Commit 7e0c5c73 authored by patrickvonplaten

changed do_output_past function to check for self.config.output_past instead of self.output_past

parent eeaa402c
@@ -539,10 +539,10 @@ class PreTrainedModel(nn.Module):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
-    def _has_past(self, outputs):
-        # TODO: might be better to write a self.has_past method for each individual class as is done for
+    def _do_output_past(self, outputs):
+        # TODO: might be better to write a self.do_output_past method for each individual class as is done for
         # prepare_inputs_for_generation
-        if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1:
+        if hasattr(self.config, 'output_past') and self.config.output_past and len(outputs) > 1 and not hasattr(self, 'mem_len'):
             return True
         # TODO: Add cases for (xlnet, transfo_xl) using mem_len
         return False
@@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module):
             next_token_logits = outputs[0][:, -1, :]
 
             # if model has past, then set the past variable to speed up decoding
-            if self._has_past(outputs):
+            if self._do_output_past(outputs):
                 past = outputs[1]
 
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
@@ -819,7 +819,7 @@ class PreTrainedModel(nn.Module):
             scores = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
 
             # if model has past, then set the past variable to speed up decoding
-            if self._has_past(outputs):
+            if self._do_output_past(outputs):
                 past = outputs[1]
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
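In short, the renamed check reads output_past from the model's config rather than from the model object itself, and it now excludes models that manage their own memory through mem_len (XLNet, Transformer-XL). Below is a minimal standalone sketch of just that gating logic; DummyConfig and DummyModel are hypothetical stand-ins for illustration, not the actual transformers classes.

    # Sketch of the new gating logic only; DummyConfig/DummyModel are
    # made-up stand-ins, not the real transformers classes.
    class DummyConfig:
        output_past = True  # output_past lives on the config, not on the model

    class DummyModel:
        def __init__(self):
            self.config = DummyConfig()

        def _do_output_past(self, outputs):
            # Reuse past states only if the config requests them, the model
            # returned more than just the logits, and the model does not
            # manage its own memory via mem_len (XLNet / Transformer-XL).
            return (
                hasattr(self.config, "output_past")
                and self.config.output_past
                and len(outputs) > 1
                and not hasattr(self, "mem_len")
            )

    model = DummyModel()
    outputs = ("logits", "past_key_values")  # placeholder outputs tuple
    print(model._do_output_past(outputs))    # True -> generation can cache `past`

With output_past disabled in the config, or on a model that exposes mem_len, the same call returns False and generation falls back to recomputing attention over the full sequence at each step.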