Unverified commit 812def00, authored by tommccoy, committed by GitHub

fix use of mems in Transformer-XL (#4826)

Fixed duplicated use of mems in Transformer-XL generation, which led to incorrect predictions and degraded performance.
parent 306f1a26
@@ -1016,11 +1016,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         return self.crit.out_layers[-1]

     def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
-        inputs = {"input_ids": input_ids}
+        inputs = {}
         # if past is defined in model kwargs then use it for faster decoding
         if past:
             inputs["mems"] = past
+            inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1)
+        else:
+            inputs["input_ids"] = input_ids
         return inputs
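The intent of the fix is that when cached memories (`past`) are available, only the newest token needs to be fed to the model, since earlier positions are already encoded in the mems; feeding the full sequence again duplicates that context. A minimal sketch of the corrected logic, using plain Python lists in place of torch tensors (the real method slices with `input_ids[:, -1].unsqueeze(-1)`):

```python
def prepare_inputs_for_generation(input_ids, past):
    """Sketch of the fixed logic, with nested lists standing in for tensors."""
    inputs = {}
    # if past is defined then use it for faster decoding
    if past:
        inputs["mems"] = past
        # keep only the last token of each sequence; earlier tokens are
        # already represented in the cached mems
        inputs["input_ids"] = [[seq[-1]] for seq in input_ids]
    else:
        # no cache yet: feed the full sequence
        inputs["input_ids"] = input_ids
    return inputs
```

With a cache present, only the last token is passed along with the mems; without one, the full sequence is used.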