Unverified Commit bad35839 authored by Jean Vancoppenolle, committed by GitHub

Add support for pretraining recurring span selection to Splinter (#17247)



* Add SplinterForSpanSelection for pre-training recurring span selection.

* Formatting.

* Rename SplinterForSpanSelection to SplinterForPreTraining.

* Ensure repo consistency

* Fixup changes

* Address SplinterForPreTraining PR comments

* Incorporate feedback and derive multiple question tokens per example.

* Update src/transformers/models/splinter/modeling_splinter.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/models/splinter/modeling_splinter.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Jean Vancoppenole <jean.vancoppenolle@retresco.de>
Co-authored-by: Tobias Günther <tobias.guenther@retresco.de>
Co-authored-by: Tobias Günther <github@tobigue.de>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 05113055
@@ -72,3 +72,8 @@ This model was contributed by [yuvalkirstain](https://huggingface.co/yuvalkirstain).

[[autodoc]] SplinterForQuestionAnswering
    - forward

## SplinterForPreTraining

[[autodoc]] SplinterForPreTraining
    - forward
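For reviewers, a minimal usage sketch of the new head (not part of the diff): the checkpoint name `tau/splinter-base-qass`, the `[QUESTION]` token id 104, and the token ids below are taken from the slow integration test added further down, so treat them as illustrative rather than documented API.

```python
# Sketch only: mirrors the slow integration test, not an official doc example.
import torch
from transformers import SplinterForPreTraining

model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")

# "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
input_ids = torch.tensor(
    [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
)
question_positions = torch.tensor([[1, 5]], dtype=torch.long)  # indices of the two [QUESTION] tokens

outputs = model(input_ids, question_positions=question_positions)
# start/end logits have shape (batch_size, num_questions, sequence_length)
print(outputs.start_logits.shape, outputs.end_logits.shape)
```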
@@ -1532,6 +1532,7 @@ else:
    _import_structure["models.splinter"].extend(
        [
            "SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
            "SplinterForPreTraining",
            "SplinterForQuestionAnswering",
            "SplinterLayer",
            "SplinterModel",
@@ -3830,6 +3831,7 @@ if TYPE_CHECKING:
    from .models.speech_to_text_2 import Speech2Text2ForCausalLM, Speech2Text2PreTrainedModel
    from .models.splinter import (
        SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
        SplinterForPreTraining,
        SplinterForQuestionAnswering,
        SplinterLayer,
        SplinterModel,
...
@@ -161,6 +161,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
...
@@ -42,6 +42,7 @@ else:
    _import_structure["modeling_splinter"] = [
        "SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "SplinterForQuestionAnswering",
        "SplinterForPreTraining",
        "SplinterLayer",
        "SplinterModel",
        "SplinterPreTrainedModel",
@@ -68,6 +69,7 @@ if TYPE_CHECKING:
    else:
        from .modeling_splinter import (
            SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST,
            SplinterForPreTraining,
            SplinterForQuestionAnswering,
            SplinterLayer,
            SplinterModel,
...
@@ -16,6 +16,7 @@

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
@@ -24,7 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -940,3 +941,171 @@ class SplinterForQuestionAnswering(SplinterPreTrainedModel):
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@dataclass
class SplinterForPreTrainingOutput(ModelOutput):
    """
    Class for outputs of Splinter as a span selection model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@add_start_docstrings(
    """
    Splinter Model for the recurring span selection task as done during the pretraining. The difference from the QA
    task is that we do not have a question, but multiple question tokens that replace the occurrences of recurring
    spans instead.
    """,
    SPLINTER_START_DOCSTRING,
)
class SplinterForPreTraining(SplinterPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.splinter = SplinterModel(config)
        self.splinter_qass = QuestionAwareSpanSelectionHead(config)
        self.question_token_id = config.question_token_id

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(
        SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length")
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        question_positions: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, SplinterForPreTrainingOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
            num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
            the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
            sequence_length)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if question_positions is None and start_positions is not None and end_positions is not None:
            raise TypeError("question_positions must be specified in order to calculate the loss")
        elif question_positions is None and input_ids is None:
            raise TypeError("question_positions must be specified when inputs_embeds is used")
        elif question_positions is None:
            question_positions = self._prepare_question_positions(input_ids)

        outputs = self.splinter(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        batch_size, sequence_length, dim = sequence_output.size()
        # [batch_size, num_questions, sequence_length]
        start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)

        num_questions = question_positions.size(1)
        if attention_mask is not None:
            attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
                batch_size, num_questions, sequence_length
            )
            start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0
            end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0

        total_loss = None
        # [batch_size, num_questions, sequence_length]
        if start_positions is not None and end_positions is not None:
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            start_positions.clamp_(0, max(0, sequence_length - 1))
            end_positions.clamp_(0, max(0, sequence_length - 1))

            # Ignore zero positions in the loss. Splinter never predicts zero
            # during pretraining and zero is used for padding question
            # tokens as well as for start and end positions of padded
            # question tokens.
            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            start_loss = loss_fct(
                start_logits.view(batch_size * num_questions, sequence_length),
                start_positions.view(batch_size * num_questions),
            )
            end_loss = loss_fct(
                end_logits.view(batch_size * num_questions, sequence_length),
                end_positions.view(batch_size * num_questions),
            )
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return SplinterForPreTrainingOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor:
        rows, flat_positions = torch.where(input_ids == self.config.question_token_id)
        num_questions = torch.bincount(rows)
        positions = torch.full(
            (input_ids.size(0), num_questions.max()),
            self.config.pad_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        cols = torch.cat([torch.arange(n) for n in num_questions])
        positions[rows, cols] = flat_positions
        return positions
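For reference (not part of the diff), a small standalone sketch of what `_prepare_question_positions` computes: it left-aligns the index of every `[QUESTION]` token per row and pads the remainder with `pad_token_id`. The inputs are the first two rows from `test_splinter_pretraining_prepare_question_positions` below, assuming `question_token_id=104` and `pad_token_id=0` as exercised by those tests.

```python
import torch

question_token_id, pad_token_id = 104, 0  # assumed values, taken from the tests below

input_ids = torch.tensor(
    [
        [101, 104, 1, 2, 104, 3, 4, 102],
        [101, 1, 104, 2, 104, 3, 104, 102],
    ]
)

# Same logic as _prepare_question_positions, spelled out step by step.
rows, flat_positions = torch.where(input_ids == question_token_id)   # locations of [QUESTION] tokens
num_questions = torch.bincount(rows)                                  # questions per row: [2, 3]
positions = torch.full((input_ids.size(0), int(num_questions.max())), pad_token_id, dtype=torch.long)
cols = torch.cat([torch.arange(int(n)) for n in num_questions])       # per-row slot indices
positions[rows, cols] = flat_positions
print(positions)  # tensor([[1, 4, 0], [2, 4, 6]])
```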
@@ -14,7 +14,7 @@
# limitations under the License.
""" Testing suite for the PyTorch Splinter model. """

import copy
import unittest

from transformers import is_torch_available
@@ -27,7 +27,7 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attenti

if is_torch_available():
    import torch

    from transformers import SplinterConfig, SplinterForPreTraining, SplinterForQuestionAnswering, SplinterModel
    from transformers.models.splinter.modeling_splinter import SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -36,6 +36,7 @@ class SplinterModelTester:
        self,
        parent,
        batch_size=13,
        num_questions=3,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
@@ -43,6 +44,7 @@ class SplinterModelTester:
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        question_token_id=1,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
@@ -59,6 +61,7 @@ class SplinterModelTester:
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.num_questions = num_questions
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
@@ -66,6 +69,7 @@ class SplinterModelTester:
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.question_token_id = question_token_id
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
@@ -82,6 +86,7 @@ class SplinterModelTester:
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids[:, 1] = self.question_token_id

        input_mask = None
        if self.use_input_mask:
@@ -91,13 +96,13 @@ class SplinterModelTester:
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        start_positions = None
        end_positions = None
        question_positions = None
        if self.use_labels:
            start_positions = ids_tensor([self.batch_size, self.num_questions], self.type_sequence_label_size)
            end_positions = ids_tensor([self.batch_size, self.num_questions], self.type_sequence_label_size)
            question_positions = ids_tensor([self.batch_size, self.num_questions], self.num_labels)

        config = SplinterConfig(
            vocab_size=self.vocab_size,
@@ -112,12 +117,20 @@ class SplinterModelTester:
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            question_token_id=self.question_token_id,
        )

        return (config, input_ids, token_type_ids, input_mask, start_positions, end_positions, question_positions)

    def create_and_check_model(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        start_positions,
        end_positions,
        question_positions,
    ):
        model = SplinterModel(config=config)
        model.to(torch_device)
@@ -128,7 +141,14 @@ class SplinterModelTester:
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_question_answering(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        start_positions,
        end_positions,
        question_positions,
    ):
        model = SplinterForQuestionAnswering(config=config)
        model.to(torch_device)
@@ -137,12 +157,36 @@ class SplinterModelTester:
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=start_positions[:, 0],
            end_positions=end_positions[:, 0],
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_for_pretraining(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        start_positions,
        end_positions,
        question_positions,
    ):
        model = SplinterForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=start_positions,
            end_positions=end_positions,
            question_positions=question_positions,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.num_questions, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.num_questions, self.seq_length))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
@@ -150,11 +194,15 @@ class SplinterModelTester:
            input_ids,
            token_type_ids,
            input_mask,
            start_positions,
            end_positions,
            question_positions,
        ) = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": input_mask,
        }
        return config, inputs_dict
@@ -165,11 +213,44 @@ class SplinterModelTest(ModelTesterMixin, unittest.TestCase):
        (
            SplinterModel,
            SplinterForQuestionAnswering,
            SplinterForPreTraining,
        )
        if is_torch_available()
        else ()
    )

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = copy.deepcopy(inputs_dict)

        if return_labels:
            if issubclass(model_class, SplinterForPreTraining):
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size,
                    self.model_tester.num_questions,
                    dtype=torch.long,
                    device=torch_device,
                )
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size,
                    self.model_tester.num_questions,
                    dtype=torch.long,
                    device=torch_device,
                )
                inputs_dict["question_positions"] = torch.zeros(
                    self.model_tester.batch_size,
                    self.model_tester.num_questions,
                    dtype=torch.long,
                    device=torch_device,
                )
            elif issubclass(model_class, SplinterForQuestionAnswering):
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )

        return inputs_dict
    def setUp(self):
        self.model_tester = SplinterModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SplinterConfig, hidden_size=37)
@@ -191,6 +272,44 @@ class SplinterModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_pretraining(*config_and_inputs)

    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs["inputs_embeds"] = wte(input_ids)
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                if isinstance(model, SplinterForPreTraining):
                    with self.assertRaises(TypeError):
                        # question_positions must not be None.
                        model(**inputs)[0]
                else:
                    model(**inputs)[0]
    @slow
    def test_model_from_pretrained(self):
        for model_name in SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -217,3 +336,122 @@ class SplinterModelIntegrationTest(unittest.TestCase):
        self.assertEqual(torch.argmax(output.start_logits), 10)
        self.assertEqual(torch.argmax(output.end_logits), 12)

    @slow
    def test_splinter_pretraining(self):
        model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")

        # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
        # Output should be the spans "Brad" and "the United Kingdom"
        input_ids = torch.tensor(
            [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
        )
        question_positions = torch.tensor([[1, 5]], dtype=torch.long)
        output = model(input_ids, question_positions=question_positions)

        expected_shape = torch.Size((1, 2, 16))
        self.assertEqual(output.start_logits.shape, expected_shape)
        self.assertEqual(output.end_logits.shape, expected_shape)

        self.assertEqual(torch.argmax(output.start_logits[0, 0]), 7)
        self.assertEqual(torch.argmax(output.end_logits[0, 0]), 7)
        self.assertEqual(torch.argmax(output.start_logits[0, 1]), 10)
        self.assertEqual(torch.argmax(output.end_logits[0, 1]), 12)

    @slow
    def test_splinter_pretraining_loss_requires_question_positions(self):
        model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")

        # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
        # Output should be the spans "Brad" and "the United Kingdom"
        input_ids = torch.tensor(
            [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]]
        )
        start_positions = torch.tensor([[7, 10]], dtype=torch.long)
        end_positions = torch.tensor([7, 12], dtype=torch.long)
        with self.assertRaises(TypeError):
            model(
                input_ids,
                start_positions=start_positions,
                end_positions=end_positions,
            )

    @slow
    def test_splinter_pretraining_loss(self):
        model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")

        # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
        # Output should be the spans "Brad" and "the United Kingdom"
        input_ids = torch.tensor(
            [
                [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
                [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
            ]
        )
        start_positions = torch.tensor([[7, 10], [7, 10]], dtype=torch.long)
        end_positions = torch.tensor([[7, 12], [7, 12]], dtype=torch.long)
        question_positions = torch.tensor([[1, 5], [1, 5]], dtype=torch.long)
        output = model(
            input_ids,
            start_positions=start_positions,
            end_positions=end_positions,
            question_positions=question_positions,
        )
        self.assertAlmostEqual(output.loss.item(), 0.0024, 4)

    @slow
    def test_splinter_pretraining_loss_with_padding(self):
        model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")

        # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]"
        # Output should be the spans "Brad" and "the United Kingdom"
        input_ids = torch.tensor(
            [
                [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102],
            ]
        )
        start_positions = torch.tensor([[7, 10]], dtype=torch.long)
        end_positions = torch.tensor([7, 12], dtype=torch.long)
        question_positions = torch.tensor([[1, 5]], dtype=torch.long)
        start_positions_with_padding = torch.tensor([[7, 10, 0]], dtype=torch.long)
        end_positions_with_padding = torch.tensor([7, 12, 0], dtype=torch.long)
        question_positions_with_padding = torch.tensor([[1, 5, 0]], dtype=torch.long)
        output = model(
            input_ids,
            start_positions=start_positions,
            end_positions=end_positions,
            question_positions=question_positions,
        )
        output_with_padding = model(
            input_ids,
            start_positions=start_positions_with_padding,
            end_positions=end_positions_with_padding,
            question_positions=question_positions_with_padding,
        )

        self.assertAlmostEqual(output.loss.item(), output_with_padding.loss.item(), 4)

        # Note that the original code uses 0 to denote padded question tokens
        # and their start and end positions. As the pad_token_id of the model's
        # config is used for the loss's ignore_index in SplinterForPreTraining,
        # we add this test to ensure anybody making changes to the default
        # value of the config will be aware of the implication.
        self.assertEqual(model.config.pad_token_id, 0)
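A tiny standalone sketch (not part of the diff) of why the padded entries drop out of the loss in the test above: with `ignore_index=0`, targets equal to the padding value contribute nothing to the cross-entropy, so padded and unpadded inputs yield the same loss. The shapes and target values below are illustrative only.

```python
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(3, 16)            # 3 question slots, 16 sequence positions
targets = torch.tensor([7, 12, 0])     # third slot is a padded question (position 0)

loss_fct = CrossEntropyLoss(ignore_index=0)
loss_padded = loss_fct(logits, targets)
loss_unpadded = loss_fct(logits[:2], targets[:2])
assert torch.isclose(loss_padded, loss_unpadded)  # the padded slot is ignored
```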

    @slow
    def test_splinter_pretraining_prepare_question_positions(self):
        model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass")

        input_ids = torch.tensor(
            [
                [101, 104, 1, 2, 104, 3, 4, 102],
                [101, 1, 104, 2, 104, 3, 104, 102],
                [101, 1, 2, 104, 104, 3, 4, 102],
                [101, 1, 2, 3, 4, 5, 104, 102],
            ]
        )
        question_positions = torch.tensor([[1, 4, 0], [2, 4, 6], [3, 4, 0], [6, 0, 0]], dtype=torch.long)
        output_without_positions = model(input_ids)
        output_with_positions = model(input_ids, question_positions=question_positions)
        self.assertTrue((output_without_positions.start_logits == output_with_positions.start_logits).all())
        self.assertTrue((output_without_positions.end_logits == output_with_positions.end_logits).all())