Unverified Commit 62098b93 authored by Ikuya Yamada, committed by GitHub

Adding fine-tuning models to LUKE (#18353)

* add LUKE models for downstream tasks

* add new LUKE models to docs

* fix typos

* remove commented lines

* exclude None items from tuple return values
parent 7b9e995b
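For orientation, these are the four task heads this commit adds on top of the pre-trained LUKE encoder, as they are exposed from the top-level package after the merge. A minimal sketch, assuming the `studio-ousia/luke-base` checkpoint from the LUKE docs; the task head itself is freshly initialized and still needs fine-tuning:

```python
from transformers import (
    LukeForMultipleChoice,
    LukeForQuestionAnswering,
    LukeForSequenceClassification,
    LukeForTokenClassification,
)

# the encoder weights come from the pre-trained checkpoint; the classification
# head on top is randomly initialized until fine-tuned on a downstream task
model = LukeForSequenceClassification.from_pretrained("studio-ousia/luke-base", num_labels=2)
```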
@@ -152,3 +152,23 @@ This model was contributed by [ikuyamada](https://huggingface.co/ikuyamada) and
[[autodoc]] LukeForEntitySpanClassification
    - forward

## LukeForSequenceClassification

[[autodoc]] LukeForSequenceClassification
    - forward

## LukeForMultipleChoice

[[autodoc]] LukeForMultipleChoice
    - forward

## LukeForTokenClassification

[[autodoc]] LukeForTokenClassification
    - forward

## LukeForQuestionAnswering

[[autodoc]] LukeForQuestionAnswering
    - forward
@@ -1363,6 +1363,10 @@ else:
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"LukeForMultipleChoice",
"LukeForQuestionAnswering",
"LukeForSequenceClassification",
"LukeForTokenClassification",
"LukeForMaskedLM",
"LukeModel",
"LukePreTrainedModel",
@@ -3953,6 +3957,10 @@ if TYPE_CHECKING:
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukePreTrainedModel,
)
@@ -170,6 +170,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("lxmert", "LxmertForPreTraining"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
@@ -230,6 +231,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("led", "LEDForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("longt5", "LongT5ForConditionalGeneration"),
("luke", "LukeForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("megatron-bert", "MegatronBertForCausalLM"),
@@ -499,6 +501,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("luke", "LukeForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
@@ -551,6 +554,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("luke", "LukeForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
@@ -611,6 +615,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("luke", "LukeForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
@@ -647,6 +652,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
("funnel", "FunnelForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("luke", "LukeForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"),
@@ -37,6 +37,10 @@ else:
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"LukeForMultipleChoice",
"LukeForQuestionAnswering",
"LukeForSequenceClassification",
"LukeForTokenClassification",
"LukeForMaskedLM",
"LukeModel",
"LukePreTrainedModel",
@@ -59,6 +63,10 @@ if TYPE_CHECKING:
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukePreTrainedModel,
)
@@ -74,6 +74,8 @@ class LukeConfig(PretrainedConfig):
Whether or not the model should use the entity-aware self-attention mechanism proposed in [LUKE: Deep
Contextualized Entity Representations with Entity-aware Self-attention (Yamada et
al.)](https://arxiv.org/abs/2010.01057).
classifier_dropout (`float`, *optional*):
The dropout ratio for the classification head.
Examples:
@@ -108,6 +110,7 @@ class LukeConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
use_entity_aware_attention=True,
classifier_dropout=None,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
@@ -131,3 +134,4 @@ class LukeConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_entity_aware_attention = use_entity_aware_attention
self.classifier_dropout = classifier_dropout
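A quick sketch of the new knob: when `classifier_dropout` is set it overrides `hidden_dropout_prob` for the classification heads only, and the `None` default falls back to `hidden_dropout_prob`. The tiny dimensions below are illustrative so the model builds quickly:

```python
from transformers import LukeConfig, LukeForSequenceClassification

config = LukeConfig(
    vocab_size=100,        # illustrative toy sizes; real checkpoints use the
    entity_vocab_size=10,  # defaults (hidden_size=768, entity_vocab_size=500000, ...)
    entity_emb_size=8,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    num_labels=3,
    classifier_dropout=0.2,
)
model = LukeForSequenceClassification(config)
assert model.dropout.p == 0.2  # the head uses the override, not hidden_dropout_prob
```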
@@ -21,6 +21,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
@@ -28,6 +29,7 @@ from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -247,6 +249,147 @@ class EntitySpanClassificationOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeSequenceClassifierOutput(ModelOutput):
"""
Outputs of sentence classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
layer plus the initial entity embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
layer plus the initial entity embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeQuestionAnsweringModelOutput(ModelOutput):
"""
Outputs of question answering models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
layer plus the initial entity embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
start_logits: torch.FloatTensor = None
end_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeMultipleChoiceModelOutput(ModelOutput):
"""
Outputs of multiple choice models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
Classification scores (before SoftMax). *num_choices* is the second dimension of the input tensors
(see *input_ids* above).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
layer plus the initial entity embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class LukeEmbeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -1240,15 +1383,20 @@ class LukeForMaskedLM(LukePreTrainedModel):
loss = loss + mep_loss
if not return_dict:
output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
if mlm_loss is not None and mep_loss is not None:
return (loss, mlm_loss, mep_loss) + output
elif mlm_loss is not None:
return (loss, mlm_loss) + output
elif mep_loss is not None:
return (loss, mep_loss) + output
else:
return output
return tuple(
v
for v in [
loss,
mlm_loss,
mep_loss,
logits,
entity_logits,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
]
if v is not None
)
return LukeMaskedLMOutput(
loss=loss,
@@ -1360,13 +1508,11 @@ class LukeForEntityClassification(LukePreTrainedModel):
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
output = (
logits,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
return tuple(
v
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
if v is not None
)
return ((loss,) + output) if loss is not None else output
return EntityClassificationOutput(
loss=loss,
@@ -1480,13 +1626,11 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
output = (
logits,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
return tuple(
v
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
if v is not None
)
return ((loss,) + output) if loss is not None else output
return EntityPairClassificationOutput(
loss=loss,
@@ -1620,13 +1764,11 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
output = (
logits,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
return tuple(
v
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
if v is not None
)
return ((loss,) + output) if loss is not None else output
return EntitySpanClassificationOutput(
loss=loss,
@@ -1635,3 +1777,460 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
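The refactored `return` paths above implement the "exclude None items" bullet from the commit message: with `return_dict=False`, outputs that were not computed are dropped from the tuple instead of being returned as `None` placeholders. A standalone sketch of the pattern and its consequence:

```python
import torch

loss = None            # no labels were passed, so no loss was computed
logits = torch.zeros(1, 2)
hidden_states = None   # output_hidden_states was not requested
attentions = None      # output_attentions was not requested

# same generator expression as in the refactored returns above
output = tuple(v for v in [loss, logits, hidden_states, attentions] if v is not None)
assert len(output) == 1 and output[0] is logits
```

Because tuple positions now shift with the requested outputs, callers that need stable access should pass `return_dict=True` and read outputs by attribute.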
@add_start_docstrings(
"""
The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
LUKE_START_DOCSTRING,
)
class LukeForSequenceClassification(LukePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.luke = LukeModel(config)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=LukeSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.FloatTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LukeSequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
return tuple(
v
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
if v is not None
)
return LukeSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
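A minimal usage sketch for the new head, assuming the `studio-ousia/luke-base` checkpoint; the sentence and label are made up, and the classifier weights are random until fine-tuned:

```python
import torch
from transformers import LukeForSequenceClassification, LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForSequenceClassification.from_pretrained("studio-ousia/luke-base", num_labels=2)

inputs = tokenizer("LUKE adds entity-aware attention to RoBERTa.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, labels=torch.tensor([1]))  # labels are optional

print(outputs.loss, outputs.logits.shape)  # scalar loss, logits of shape (1, 2)
```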
@add_start_docstrings(
"""
The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To
solve the Named-Entity Recognition (NER) task with LUKE, `LukeForEntitySpanClassification` is more suitable than
this class.
""",
LUKE_START_DOCSTRING,
)
class LukeForTokenClassification(LukePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.luke = LukeModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=LukeTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.FloatTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LukeTokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels -
1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
return tuple(
v
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
if v is not None
)
return LukeTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
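A matching sketch for token classification; `num_labels=5` stands in for an arbitrary tag set, and predictions are per token:

```python
import torch
from transformers import LukeForTokenClassification, LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForTokenClassification.from_pretrained("studio-ousia/luke-base", num_labels=5)

inputs = tokenizer("Beyoncé lives in Los Angeles.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, sequence_length, num_labels)

predictions = logits.argmax(-1)  # one label id per token
```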
@add_start_docstrings(
"""
The LUKE Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
LUKE_START_DOCSTRING,
)
class LukeForQuestionAnswering(LukePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.luke = LukeModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=LukeQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.FloatTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
sequence_output = outputs.last_hidden_state
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, splitting adds an extra dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
return tuple(
v
for v in [
total_loss,
start_logits,
end_logits,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
]
if v is not None
)
return LukeQuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
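A usage sketch for span extraction; since the QA head is untrained here, the decoded span is arbitrary until the model is fine-tuned on something like SQuAD:

```python
import torch
from transformers import LukeForQuestionAnswering, LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForQuestionAnswering.from_pretrained("studio-ousia/luke-base")

question, context = "Who developed LUKE?", "LUKE was developed by Studio Ousia."
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# take the most likely start/end positions and decode the tokens in between
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```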
@add_start_docstrings(
"""
The LUKE Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
LUKE_START_DOCSTRING,
)
class LukeForMultipleChoice(LukePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.luke = LukeModel(config)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.classifier = nn.Linear(config.hidden_size, 1)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=LukeMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.FloatTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LukeMultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)
entity_ids = entity_ids.view(-1, entity_ids.size(-1)) if entity_ids is not None else None
entity_attention_mask = (
entity_attention_mask.view(-1, entity_attention_mask.size(-1))
if entity_attention_mask is not None
else None
)
entity_token_type_ids = (
entity_token_type_ids.view(-1, entity_token_type_ids.size(-1))
if entity_token_type_ids is not None
else None
)
entity_position_ids = (
entity_position_ids.view(-1, entity_position_ids.size(-2), entity_position_ids.size(-1))
if entity_position_ids is not None
else None
)
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
return tuple(
v
for v in [
loss,
reshaped_logits,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
]
if v is not None
)
return LukeMultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
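A sketch of the expected input layout: every tensor gains a `num_choices` dimension, and the single-logit head is reshaped to `(batch_size, num_choices)`. Prompt and choices are made up:

```python
import torch
from transformers import LukeForMultipleChoice, LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForMultipleChoice.from_pretrained("studio-ousia/luke-base")

prompt = "The capital of France is"
choices = ["Paris.", "a prime number."]

# encode each (prompt, choice) pair, then add the num_choices dimension
encoding = tokenizer([prompt] * len(choices), choices, return_tensors="pt", padding=True)
inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # (1, num_choices, seq_len)

with torch.no_grad():
    outputs = model(**inputs, labels=torch.tensor([0]))  # label: index of the right choice

print(outputs.logits.shape)  # (1, num_choices)
```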
@@ -2736,6 +2736,34 @@ class LukeForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class LukeForMultipleChoice(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -30,6 +30,10 @@ if is_torch_available():
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukeTokenizer,
)
@@ -66,6 +70,8 @@ class LukeModelTester:
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
num_entity_classification_labels=9,
num_entity_pair_classification_labels=6,
num_entity_span_classification_labels=4,
@@ -99,6 +105,8 @@ class LukeModelTester:
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.num_entity_classification_labels = num_entity_classification_labels
self.num_entity_pair_classification_labels = num_entity_pair_classification_labels
self.num_entity_span_classification_labels = num_entity_span_classification_labels
@@ -139,7 +147,8 @@ class LukeModelTester:
)
sequence_labels = None
labels = None
token_labels = None
choice_labels = None
entity_labels = None
entity_classification_labels = None
entity_pair_classification_labels = None
@@ -147,7 +156,9 @@ class LukeModelTester:
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
@@ -170,7 +181,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -207,7 +219,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -247,7 +260,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -266,7 +280,7 @@ class LukeModelTester:
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=labels,
labels=token_labels,
entity_labels=entity_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
@@ -288,7 +302,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -322,7 +337,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -356,7 +372,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -386,6 +403,156 @@ class LukeModelTester:
result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
)
def create_and_check_for_question_answering(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
model = LukeForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_for_sequence_classification(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_labels
model = LukeForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=sequence_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_labels
model = LukeForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=token_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_multiple_choice(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_choices = self.num_choices
model = LukeForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_entity_token_type_ids = (
entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
)
multiple_choice_entity_attention_mask = (
entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
)
multiple_choice_entity_position_ids = (
entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous()
)
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_attention_mask,
token_type_ids=multiple_choice_token_type_ids,
entity_ids=multiple_choice_entity_ids,
entity_attention_mask=multiple_choice_entity_attention_mask,
entity_token_type_ids=multiple_choice_entity_token_type_ids,
entity_position_ids=multiple_choice_entity_position_ids,
labels=choice_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
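The tester builds its multiple-choice batch by repeating each 2-D tensor along a new `num_choices` axis. A self-contained sketch of that `unsqueeze`/`expand` trick with made-up sizes:

```python
import torch

batch_size, num_choices, seq_length = 2, 4, 7
input_ids = torch.randint(0, 100, (batch_size, seq_length))

# unsqueeze inserts the choice axis, expand repeats it without copying data,
# and contiguous() materializes the layout that the model's .view() calls need
mc_input_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
assert mc_input_ids.shape == (batch_size, num_choices, seq_length)
```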
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -398,7 +565,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -426,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeForMultipleChoice,
)
if is_torch_available()
else ()
@@ -436,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
test_head_masking = True
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")}
inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")}
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if model_class == LukeForMultipleChoice:
entity_inputs_dict = {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if v.ndim == 2
else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous()
for k, v in entity_inputs_dict.items()
}
inputs_dict.update(entity_inputs_dict)
if model_class == LukeForEntitySpanClassification:
inputs_dict["entity_start_positions"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
@@ -446,7 +630,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
)
if return_labels:
if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
if model_class in (
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForSequenceClassification,
LukeForMultipleChoice,
):
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
@@ -456,6 +645,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForTokenClassification:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForMaskedLM:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
@@ -496,6 +691,22 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)