Commit 76c0bc06 authored by Rostislav Nedelchev's avatar Rostislav Nedelchev
Browse files

[XLNet] Changed post-processing of attention w.r.t to target_mapping

Whenever target_mapping is provided to the input, XLNet outputs two different attention streams.
Based on that the attention output would be on of the two:
- a list of tensors (usual case for most transformers)
- a list of 2-tuples of tensors, one tesor for each of attention streams
Docs and unit-tests have been updated
parent b90791e9
...@@ -581,8 +581,9 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -581,8 +581,9 @@ class XLNetModel(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
Examples:: Examples::
...@@ -878,7 +879,11 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -878,7 +879,11 @@ class XLNetModel(XLNetPreTrainedModel):
hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states) hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
outputs = outputs + (hidden_states,) outputs = outputs + (hidden_states,)
if self.output_attentions: if self.output_attentions:
attentions = tuple(tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions) if target_mapping is not None:
# when target_mapping is provided, there are 2-tuple of attentions
attentions = tuple(tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions)
else:
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs = outputs + (attentions,) outputs = outputs + (attentions,)
return outputs # outputs, (new_mems), (hidden_states), (attentions) return outputs # outputs, (new_mems), (hidden_states), (attentions)
...@@ -911,8 +916,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -911,8 +916,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
Examples:: Examples::
...@@ -993,8 +999,9 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -993,8 +999,9 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
Examples:: Examples::
...@@ -1093,8 +1100,9 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): ...@@ -1093,8 +1100,9 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
Examples:: Examples::
...@@ -1178,8 +1186,9 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): ...@@ -1178,8 +1186,9 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
Examples:: Examples::
...@@ -1292,8 +1301,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): ...@@ -1292,8 +1301,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``) **attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
Examples:: Examples::
......
...@@ -163,6 +163,18 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -163,6 +163,18 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetModel(config)
model.eval()
_, _, attentions = model(input_ids_1, target_mapping=target_mapping)
self.parent.assertEqual(len(attentions), config.n_layer)
self.parent.assertIsInstance(attentions[0], tuple)
self.parent.assertEqual(len(attentions[0]), 2)
self.parent.assertTrue(attentions[0][0].shape, attentions[0][0].shape)
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
model = XLNetLMHeadModel(config) model = XLNetLMHeadModel(config)
...@@ -306,6 +318,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -306,6 +318,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs) self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
def test_xlnet_base_model_with_att_output(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs[0].output_attentions = True
self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs)
def test_xlnet_lm_head(self): def test_xlnet_lm_head(self):
self.model_tester.set_seed() self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment