".github/vscode:/vscode.git/clone" did not exist on "7dfdc0a5563abe80a85f9f7fa0c3b2ef458e7783"
Commit 267587c2 authored by patrickvonplaten's avatar patrickvonplaten
Browse files

add and improve comments

parent d891fd0a
......@@ -491,7 +491,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, **kwargs):
# inputs_ids contain only last token if past is in kwargs and defined
# inputs_ids should only be composed of last token if past is in kwargs and defined
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
inputs = {"input_ids": input_ids}
......
......@@ -560,7 +560,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, **kwargs):
# inputs_ids contain only last token if past is in kwargs and defined
# inputs_ids should only be composed of last token if past is in kwargs and defined
input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
inputs = {"input_ids": input_ids}
......
......@@ -732,6 +732,7 @@ class PreTrainedModel(nn.Module):
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past parameter to speed up decoding
if self._has_past(outputs):
past = outputs[1]
......@@ -819,6 +820,7 @@ class PreTrainedModel(nn.Module):
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past parameter to speed up decoding
if self._has_past(outputs):
past = outputs[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment