Commit f18ac4c2 authored by patrickvonplaten

Fix sequence length in prepare_inputs_for_generation for XLNet: read input_ids.shape[1] after the dummy token is concatenated, so the permutation mask covers the extended input.

parent 359dc438
@@ -1012,11 +1012,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         # Add dummy token at the end (no attention on this one)
         effective_batch_size = input_ids.shape[0]
-        sequence_length = input_ids.shape[1]
         dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long, device=input_ids.device)
         input_ids = torch.cat([input_ids, dummy_token], dim=1)

         # Build permutation mask so that previous tokens don't see last token
+        sequence_length = input_ids.shape[1]
         perm_mask = torch.zeros(
             (effective_batch_size, sequence_length, sequence_length), dtype=torch.float, device=input_ids.device
         )
...
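Why the one-line move matters: perm_mask is allocated as (effective_batch_size, sequence_length, sequence_length) and, per the comment in the diff, its last column is meant to hide the freshly appended dummy token. If sequence_length is read before the torch.cat, the mask comes out one position short of the extended input_ids. Below is a minimal sketch of the shape logic; the toy values and the perm_mask[:, :, -1] assignment are assumptions for illustration, not the full prepare_inputs_for_generation:

import torch

effective_batch_size, original_length = 2, 5
input_ids = torch.ones((effective_batch_size, original_length), dtype=torch.long)

# Old behaviour: length captured before the dummy token is appended
stale_length = input_ids.shape[1]  # 5

dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long)
input_ids = torch.cat([input_ids, dummy_token], dim=1)  # shape is now (2, 6)

# Fixed behaviour: length captured after the concatenation
sequence_length = input_ids.shape[1]  # 6

perm_mask = torch.zeros((effective_batch_size, sequence_length, sequence_length))
perm_mask[:, :, -1] = 1.0  # previous tokens don't see the last (dummy) token

assert perm_mask.shape[1] == input_ids.shape[1]  # 6 == 6, mask matches input
assert stale_length != input_ids.shape[1]        # the old value is off by one

With the stale length the mask would be (2, 5, 5) against a (2, 6) input, so the appended dummy token would sit outside the mask entirely; whether that surfaces as a downstream shape error or as a silently visible dummy token, the prediction for the last position would be wrong.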