Commit b90791e9 authored by Rostislav Nedelchev's avatar Rostislav Nedelchev
Browse files

fixed XLNet attenttion output for both attention streams

parent b0ee7c7d
......@@ -581,7 +581,7 @@ class XLNetModel(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
......@@ -878,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
outputs = outputs + (hidden_states,)
if self.output_attentions:
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
attentions = tuple(tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions)
outputs = outputs + (attentions,)
return outputs # outputs, (new_mems), (hidden_states), (attentions)
......@@ -911,7 +911,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
......@@ -993,7 +993,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
......@@ -1093,7 +1093,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
......@@ -1178,7 +1178,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
......@@ -1292,7 +1292,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment