".circleci/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "71c2ae7793de0c1ceb9cc96ef6d14bf3e302fb4f"
Unverified commit 492bea9a authored by Thomas Wolf, committed by GitHub

Merge pull request #2292 from patrickvonplaten/add_cached_past_for_language_generation

Add cached past for language generation
parents e213900f fc84bd52
@@ -490,6 +490,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        # only keep the last token of input_ids if a past is defined in kwargs
+        if "past" in kwargs and kwargs["past"]:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        inputs = {"input_ids": input_ids}
+        inputs.update(kwargs)
+        return inputs
+
     def forward(
         self,
         input_ids=None,
...
@@ -559,6 +559,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        # only keep the last token of input_ids if a past is defined in kwargs
+        if "past" in kwargs and kwargs["past"]:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        inputs = {"input_ids": input_ids}
+        inputs.update(kwargs)
+        return inputs
+
     def forward(
         self,
         input_ids=None,
...
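The new hook is easiest to see in terms of tensor shapes: once a cached past exists, CTRL and GPT-2 only feed the newly generated token back in, since the cached key/value states already cover the earlier positions. A minimal standalone sketch of that truncation (dummy tensors only, not tied to any pretrained weights):

import torch

def prepare_inputs_for_generation(input_ids, **kwargs):
    # same logic as the method added to CTRL and GPT-2 above:
    # with a cached past, only the last generated token is fed to the model
    if "past" in kwargs and kwargs["past"]:
        input_ids = input_ids[:, -1].unsqueeze(-1)
    inputs = {"input_ids": input_ids}
    inputs.update(kwargs)
    return inputs

generated = torch.tensor([[50, 51, 52, 53]])  # (batch_size=1, cur_len=4)
first_step = prepare_inputs_for_generation(generated, past=None)
later_step = prepare_inputs_for_generation(generated, past=("dummy cache",))
print(first_step["input_ids"].shape)  # torch.Size([1, 4]) -> full prompt on the first step
print(later_step["input_ids"].shape)  # torch.Size([1, 1]) -> only the last token afterwards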
@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
             return self.out_layer
         else:
             return self.crit.out_layers[-1]
+
+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        inputs = {"input_ids": input_ids}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if "past" in model_kwargs and model_kwargs["past"]:
+            inputs["mems"] = model_kwargs["past"]
+
+        return inputs
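Transformer-XL keeps the full input_ids and instead renames the generic `past` kwarg that generate() passes around into the `mems` argument its forward() expects. A hypothetical toy showing why that adapter step matters; the fake forward below is an illustration, not the real TransfoXL signature:

def fake_transfo_xl_forward(input_ids, mems=None):
    # stand-in for a forward() that only understands the keyword `mems`
    return {"received_mems": mems is not None}

def prepare_inputs_for_generation(input_ids, **model_kwargs):
    inputs = {"input_ids": input_ids}
    # translate the generic `past` used by generate() into the model-specific `mems`
    if "past" in model_kwargs and model_kwargs["past"]:
        inputs["mems"] = model_kwargs["past"]
    return inputs

model_inputs = prepare_inputs_for_generation([[1, 2, 3]], past=("cached segment",))
print(fake_transfo_xl_forward(**model_inputs))  # {'received_mems': True}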
@@ -539,6 +539,17 @@ class PreTrainedModel(nn.Module):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
+    def _do_output_past(self, outputs):
+        has_output_past = hasattr(self.config, "output_past") and self.config.output_past
+        has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
+
+        if has_output_past and not has_mem_len and len(outputs) > 1:
+            return True
+        elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
+            return True
+
+        return False
+
     @torch.no_grad()
     def generate(
         self,
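_do_output_past() decides whether the outputs tuple actually carries a usable cache: GPT-2/CTRL-style models advertise it through `config.output_past`, Transformer-XL/XLNet-style models through a positive `config.mem_len`, and in both cases the outputs must contain more than just the logits. A toy sketch of that decision with dummy config objects (illustrative only, not the library API):

class DummyConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

def do_output_past(config, outputs):
    # free-function copy of the logic added to PreTrainedModel above
    has_output_past = hasattr(config, "output_past") and config.output_past
    has_mem_len = hasattr(config, "mem_len") and config.mem_len
    if has_output_past and not has_mem_len and len(outputs) > 1:
        return True
    elif has_mem_len and config.mem_len > 0 and len(outputs) > 1:
        return True
    return False

outputs = ("logits", "cached_states")                            # a model that returned a cache
print(do_output_past(DummyConfig(output_past=True), outputs))    # True  (GPT-2/CTRL style)
print(do_output_past(DummyConfig(mem_len=512), outputs))         # True  (Transformer-XL style)
print(do_output_past(DummyConfig(), ("logits",)))                # False (no flag, no extra output)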
@@ -757,14 +768,17 @@ class PreTrainedModel(nn.Module):
         # current position / max lengths / length of generated sentences / unfinished sentences
         unfinished_sents = input_ids.new(batch_size).fill_(1)
 
-        # TODO: add cached compute states
-        pasts = None
+        past = None
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]
 
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
+
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
                 for i in range(batch_size):
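Putting the pieces together, the sampling loop now asks the model for its cache once per step and hands it back on the next step. A rough end-to-end sketch with GPT-2, assuming a transformers release from the same era as this PR (2.x) in which GPT2LMHeadModel returns a plain tuple (logits, presents, ...) and accepts the cache through the `past` keyword:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer.encode("The cached past makes decoding", return_tensors="pt")
past = None
with torch.no_grad():
    for _ in range(20):
        if past is not None:
            # only the newest token is needed once the cache exists
            model_inputs = {"input_ids": input_ids[:, -1].unsqueeze(-1), "past": past}
        else:
            model_inputs = {"input_ids": input_ids}
        outputs = model(**model_inputs)
        next_token_logits = outputs[0][:, -1, :]
        past = outputs[1]  # reuse the cached key/value states on the next step
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))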
@@ -838,15 +852,19 @@ class PreTrainedModel(nn.Module):
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
 
         # cache compute states
-        pasts = None  # self.prepare_pasts()
+        past = None
 
         # done sentences
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
-            scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
-            scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
+            scores = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
+
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
@@ -935,13 +953,22 @@ class PreTrainedModel(nn.Module):
             beam_words = input_ids.new([x[1] for x in next_batch_beam])
             beam_idx = input_ids.new([x[2] for x in next_batch_beam])
 
-            # re-order batch and internal states
+            # re-order batch
             input_ids = input_ids[beam_idx, :]
             input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
 
-            # TODO: Activate cache
-            # for k in cache.keys():
-            #     if k != 'slen':
-            #         cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
+            # re-order internal states
+            if past:
+                reordered_past = []
+                for layer_past in past:
+                    # get the correct batch idx from layer past batch dim
+                    # batch dim of `past` and `mems` is at 2nd position
+                    reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
+                    reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
+                    # check that shape matches
+                    assert reordered_layer_past.shape == layer_past.shape
+                    reordered_past.append(reordered_layer_past)
+                past = tuple(reordered_past)
 
             # update current length
             cur_len = cur_len + 1
...
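The re-ordering step is the subtle part of caching under beam search: whenever the beam hypotheses are re-ranked, each layer's cached states must be permuted along their batch dimension with the same beam_idx used for input_ids. A standalone sketch with dummy tensors, batch dimension at position 1 as for GPT-2/CTRL `past`:

import torch

num_layers, num_beams, hidden = 2, 4, 3
original_past = tuple(torch.randn(2, num_beams, hidden) for _ in range(num_layers))
beam_idx = torch.tensor([2, 2, 0, 1])  # beams that survived this step (duplicates allowed)

reordered_past = []
for layer_past in original_past:
    # select the surviving beams along dim 1, one beam at a time, then re-assemble
    reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
    reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
    assert reordered_layer_past.shape == layer_past.shape  # only the order changes
    reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)

# index_select over dim 1 gives the same permutation in a single call per layer
for reordered, layer_past in zip(past, original_past):
    assert torch.equal(reordered, layer_past.index_select(1, beam_idx))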
@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         )
         target_mapping[0, 0, -1] = 1.0
 
-        return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+        inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if "past" in model_kwargs and model_kwargs["past"]:
+            inputs["mems"] = model_kwargs["past"]
+
+        return inputs
 
     def forward(
         self,
...
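From the user's side nothing changes: the cache is created, threaded through prepare_inputs_for_generation(), and re-ordered entirely inside generate(), so existing calls simply get faster. A usage sketch, again assuming a transformers release from this era and its generate() keyword arguments:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer.encode("The cached past", return_tensors="pt")
with torch.no_grad():
    # both sampling and beam search now reuse the cached past internally
    output_ids = model.generate(input_ids, max_length=30, do_sample=False, num_beams=2)
print(tokenizer.decode(output_ids[0]))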