Commit deff792b authored by patrickvonplaten

add prepare inputs for transfo_xl and xlnet

parent 9398058e
@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
             return self.out_layer
         else:
             return self.crit.out_layers[-1]
+
+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        inputs = {"input_ids": input_ids}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if 'past' in model_kwargs and model_kwargs['past']:
+            inputs['mems'] = model_kwargs['past']
+
+        return inputs
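
Note: as a rough illustration of how this hook is meant to be consumed, the sketch below mimics a greedy decoding step. It assumes `model` (a TransfoXLLMHeadModel) and `input_ids` of shape (batch, seq_len) already exist, and that the model returns prediction scores first and `mems` second; treat it as a sketch, not the actual generate() loop.

import torch

# assumed to exist already: `model` (TransfoXLLMHeadModel) and `input_ids` (batch, seq_len)
past = None
for _ in range(5):
    # the hook maps the generic `past` kwarg onto TransfoXL's `mems` argument
    model_inputs = model.prepare_inputs_for_generation(input_ids, past=past)
    outputs = model(**model_inputs)          # (prediction_scores, mems, ...)
    next_token = outputs[0][:, -1, :].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)
    past = outputs[1]                        # feed the returned mems back in as `past`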
@@ -540,15 +540,14 @@ class PreTrainedModel(nn.Module):
         return {"input_ids": input_ids}
 
     def _do_output_past(self, outputs):
-        # TODO: might be better to write a self.do_output_past method for each
-        # individual class as is done for prepare_inputs_for_generation
         has_output_past = hasattr(self.config, 'output_past') and self.config.output_past
-        has_multiple_outputs = len(outputs) > 1
-        has_mem_len = hasattr(self.config, 'mem_len')
+        has_mem_len = hasattr(self.config, 'mem_len') and self.config.mem_len
 
-        if has_output_past and has_multiple_outputs and not has_mem_len:
+        if has_output_past and not has_mem_len and len(outputs) > 1:
             return True
-        # TODO: Add cases for (xlnet, transfo_xl) using mem_len
+        elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
+            return True
+
         return False
 
     @torch.no_grad()
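
To make the branching above concrete, here is a small self-contained rendition of the same predicate with stand-in config objects (SimpleNamespace and the attribute values are only for illustration; the real models carry proper config classes):

from types import SimpleNamespace

def do_output_past(config, outputs):
    # same logic as PreTrainedModel._do_output_past above
    has_output_past = hasattr(config, 'output_past') and config.output_past
    has_mem_len = hasattr(config, 'mem_len') and config.mem_len
    if has_output_past and not has_mem_len and len(outputs) > 1:
        return True
    elif has_mem_len and config.mem_len > 0 and len(outputs) > 1:
        return True
    return False

gpt2_like = SimpleNamespace(output_past=True)    # caches via `past`, no `mem_len`
xl_like = SimpleNamespace(mem_len=512)           # caches via `mems`, no `output_past`

print(do_output_past(gpt2_like, ("logits", "past")))   # True
print(do_output_past(xl_like, ("logits", "mems")))     # True
print(do_output_past(xl_like, ("logits",)))            # False: model returned no cache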
@@ -921,7 +920,8 @@ class PreTrainedModel(nn.Module):
             if past:
                 reordered_past = []
                 for layer_past in past:
-                    # copy the relevant beam idx past to past
+                    # get the correct batch idx from layer past batch dim
+                    # batch dim of `past` and `mems` is at 2nd position
                     reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
                     reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
                     # check that shape matches
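
A quick way to sanity-check the new comment about the batch dimension: for `mems`-style caches the batch sits at dim 1, so gathering beams works as below (the tensor sizes are invented for the demonstration):

import torch

mem_len, batch, hidden = 2, 3, 4                       # made-up sizes
layer_past = torch.arange(mem_len * batch * hidden, dtype=torch.float).reshape(mem_len, batch, hidden)
beam_idx = [2, 2, 0]                                   # each new beam copies this old batch index

reordered = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered = torch.cat(reordered, dim=1)

assert reordered.shape == layer_past.shape             # shape is preserved
assert torch.equal(reordered[:, 0], layer_past[:, 2])  # beam 0 now holds batch item 2's cache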
@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         )
         target_mapping[0, 0, -1] = 1.0
 
-        return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+        inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if 'past' in model_kwargs and model_kwargs['past']:
+            inputs['mems'] = model_kwargs['past']
+
+        return inputs
 
     def forward(
         self,
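
For context on the XLNet-specific tensors passed alongside the new `mems` handling, the snippet below builds dummy `perm_mask` / `target_mapping` tensors of the kind this method hands to the model (batch size, sequence length and the `[:, 0, -1]` indexing are illustrative, not copied from the file):

import torch

batch, seq_len = 2, 6                                   # made-up sizes
input_ids = torch.randint(0, 32000, (batch, seq_len))   # last position stands in for the dummy token

perm_mask = torch.zeros(batch, seq_len, seq_len)
perm_mask[:, :, -1] = 1.0                                # 1.0 = cannot attend: hide the last token from every position

target_mapping = torch.zeros(batch, 1, seq_len)
target_mapping[:, 0, -1] = 1.0                           # predict exactly one token: the last position

inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}

past = None                                              # would be the `mems` returned by a previous forward pass
if past:
    inputs["mems"] = past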