"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "83d969e3975d340ef980db59b07954a67d08ce6f"
Commit d891fd0a authored by patrickvonplaten

add past hidden key states for more efficient language generation & add prepare_inputs_for_generation for the GPT-2 and CTRL models
parent aeef4823
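
With caching, each decoding step runs the model over only the newest token and reuses the attention key/value states of the prefix from past. Below is a minimal greedy-decoding sketch of the pattern this commit enables; the "gpt2" checkpoint and the (logits, past, ...) output tuple layout are assumptions based on the transformers API of this commit's era, not part of the diff:

# Greedy decoding with cached past states (sketch, transformers 2.x era API).
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer.encode("Caching past key states", return_tensors="pt")
past = None
with torch.no_grad():
    for _ in range(20):
        # With a defined past, only the last token is fed to the model (see the
        # prepare_inputs_for_generation overrides added below).
        model_inputs = model.prepare_inputs_for_generation(input_ids, past=past)
        outputs = model(**model_inputs)
        past = outputs[1]  # cached key/value states, reused on the next step
        next_token = outputs[0][:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))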
modeling_ctrl.py
@@ -490,6 +490,14 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        # input_ids contains only the last token if past is in kwargs and defined
+        input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
+        inputs = {"input_ids": input_ids}
+        inputs.update(kwargs)
+        return inputs
+
     def forward(
         self,
         input_ids=None,
...
modeling_gpt2.py
@@ -559,6 +559,14 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        # input_ids contains only the last token if past is in kwargs and defined
+        input_ids = input_ids[:, -1].unsqueeze(-1) if 'past' in kwargs and kwargs['past'] else input_ids
+        inputs = {"input_ids": input_ids}
+        inputs.update(kwargs)
+        return inputs
+
     def forward(
         self,
         input_ids=None,
...
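
Both overrides behave the same way; a quick illustration of the truncation they perform (model, input_ids, and past are assumed to be carried over from an earlier decoding step, as in the sketch above):

inputs = model.prepare_inputs_for_generation(input_ids)             # no past given
assert inputs["input_ids"].shape == input_ids.shape                 # full sequence
inputs = model.prepare_inputs_for_generation(input_ids, past=past)  # past defined
assert inputs["input_ids"].shape == (input_ids.shape[0], 1)         # last token only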
modeling_utils.py
@@ -539,6 +540,14 @@ class PreTrainedModel(nn.Module):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
+    def _has_past(self, outputs):
+        # TODO: might be better to write a self.has_past method for each
+        # individual class, as is done for prepare_inputs_for_generation
+        if hasattr(self, 'output_past') and self.output_past and len(outputs) > 1:
+            return True
+        # TODO: add cases for (xlnet, transfo_xl) using mem_len
+        return False
+
     @torch.no_grad()
     def generate(
         self,
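
The len(outputs) > 1 check relies on a convention of this era's API (an assumption for illustration, not part of the diff): models with output_past enabled return their cache as the second element of the output tuple.

outputs = model(input_ids)   # e.g. (logits, past) when output_past is enabled
if model._has_past(outputs):
    past = outputs[1]        # handed to prepare_inputs_for_generation next step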
@@ -716,14 +725,16 @@ class PreTrainedModel(nn.Module):
         # current position / max lengths / length of generated sentences / unfinished sentences
         unfinished_sents = input_ids.new(batch_size).fill_(1)
 
-        # TODO: add cached compute states
-        pasts = None
+        past = None
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]
 
+            if self._has_past(outputs):
+                past = outputs[1]
+
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
                 for i in range(batch_size):
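
For reference, the repetition penalty the comment cites (CTRL, arXiv:1909.05858) divides the score of every already-generated token by the penalty. A sketch of one plausible body for the loop above, with variable names following the surrounding code:

for token_id in set(input_ids[i].tolist()):
    # note: dividing a negative logit by a penalty > 1 makes that token more
    # likely; later library versions multiply negative logits instead
    next_token_logits[i, token_id] /= repetition_penalty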
@@ -797,15 +809,18 @@
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
 
         # cache compute states
-        pasts = None  # self.prepare_pasts()
+        past = None
 
         # done sentences
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
-            scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
-            scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
+            scores = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
+
+            if self._has_past(outputs):
+                past = outputs[1]
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
...
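
Taken together, generate() now reuses cached states automatically for any model whose _has_past check passes. A hedged usage sketch; argument names follow the transformers 2.x signature of this commit's era and may differ in other versions:

output = model.generate(
    input_ids,
    max_length=50,
    do_sample=True,
    temperature=0.9,
    repetition_penalty=1.2,
)
print(tokenizer.decode(output[0]))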