Unverified Commit 62098b93 authored by Ikuya Yamada, committed by GitHub

Adding fine-tuning models to LUKE (#18353)

* add LUKE models for downstream tasks

* add new LUKE models to docs

* fix typos

* remove commented lines

* exclude None items from tuple return values
parent 7b9e995b
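For context, the new task heads follow the standard transformers fine-tuning API. A minimal usage sketch (the checkpoint is the public LUKE base model; the label count is illustrative, and the head is randomly initialized until fine-tuned):

```python
import torch
from transformers import LukeTokenizer, LukeForSequenceClassification

# Illustrative checkpoint and label count; the classification head is
# randomly initialized and must be fine-tuned before predictions are useful.
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForSequenceClassification.from_pretrained("studio-ousia/luke-base", num_labels=2)

inputs = tokenizer("LUKE now supports sequence classification.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
predicted_class = logits.argmax(-1).item()
```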
@@ -152,3 +152,23 @@ This model was contributed by [ikuyamada](https://huggingface.co/ikuyamada) and
[[autodoc]] LukeForEntitySpanClassification
- forward
## LukeForSequenceClassification
[[autodoc]] LukeForSequenceClassification
- forward
## LukeForMultipleChoice
[[autodoc]] LukeForMultipleChoice
- forward
## LukeForTokenClassification
[[autodoc]] LukeForTokenClassification
- forward
## LukeForQuestionAnswering
[[autodoc]] LukeForQuestionAnswering
- forward
@@ -1363,6 +1363,10 @@ else:
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"LukeForMultipleChoice",
"LukeForQuestionAnswering",
"LukeForSequenceClassification",
"LukeForTokenClassification",
"LukeForMaskedLM", "LukeForMaskedLM",
"LukeModel", "LukeModel",
"LukePreTrainedModel", "LukePreTrainedModel",
@@ -3953,6 +3957,10 @@ if TYPE_CHECKING:
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukePreTrainedModel,
)
...
@@ -170,6 +170,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("lxmert", "LxmertForPreTraining"), ("lxmert", "LxmertForPreTraining"),
("megatron-bert", "MegatronBertForPreTraining"), ("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"), ("mobilebert", "MobileBertForPreTraining"),
@@ -230,6 +231,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("led", "LEDForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("longt5", "LongT5ForConditionalGeneration"),
("luke", "LukeForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"), ("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"), ("marian", "MarianMTModel"),
("megatron-bert", "MegatronBertForCausalLM"), ("megatron-bert", "MegatronBertForCausalLM"),
@@ -499,6 +501,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("luke", "LukeForSequenceClassification"),
("mbart", "MBartForSequenceClassification"), ("mbart", "MBartForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"), ("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"),
@@ -551,6 +554,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("luke", "LukeForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"), ("lxmert", "LxmertForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"), ("mbart", "MBartForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"), ("megatron-bert", "MegatronBertForQuestionAnswering"),
@@ -611,6 +615,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("luke", "LukeForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"), ("mpnet", "MPNetForTokenClassification"),
@@ -647,6 +652,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
("funnel", "FunnelForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("luke", "LukeForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"), ("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"), ("mobilebert", "MobileBertForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"), ("mpnet", "MPNetForMultipleChoice"),
......
@@ -37,6 +37,10 @@ else:
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"LukeForMultipleChoice",
"LukeForQuestionAnswering",
"LukeForSequenceClassification",
"LukeForTokenClassification",
"LukeForMaskedLM", "LukeForMaskedLM",
"LukeModel", "LukeModel",
"LukePreTrainedModel", "LukePreTrainedModel",
@@ -59,6 +63,10 @@ if TYPE_CHECKING:
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukePreTrainedModel,
)
...
@@ -74,6 +74,8 @@ class LukeConfig(PretrainedConfig):
Whether or not the model should use the entity-aware self-attention mechanism proposed in [LUKE: Deep
Contextualized Entity Representations with Entity-aware Self-attention (Yamada et
al.)](https://arxiv.org/abs/2010.01057).
classifier_dropout (`float`, *optional*):
The dropout ratio for the classification head.
Examples:
@@ -108,6 +110,7 @@ class LukeConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
use_entity_aware_attention=True,
classifier_dropout=None,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
@@ -131,3 +134,4 @@ class LukeConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_entity_aware_attention = use_entity_aware_attention
self.classifier_dropout = classifier_dropout
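A short sketch of the new option: `classifier_dropout` applies only to the task heads, and the heads fall back to `hidden_dropout_prob` when it is left as `None`.

```python
from transformers import LukeConfig

# classifier_dropout overrides hidden_dropout_prob for the task heads only;
# with the default None, heads fall back to hidden_dropout_prob.
config = LukeConfig(classifier_dropout=0.2)
print(config.classifier_dropout)   # 0.2
print(config.hidden_dropout_prob)  # 0.1 (unchanged default)
```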
@@ -21,6 +21,7 @@ from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, gelu
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
@@ -28,6 +29,7 @@ from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@@ -247,6 +249,147 @@ class EntitySpanClassificationOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeSequenceClassifierOutput(ModelOutput):
"""
Outputs of sentence classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
layer plus the initial entity embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
layer plus the initial entity embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeQuestionAnsweringModelOutput(ModelOutput):
"""
Outputs of question answering models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
layer plus the initial entity embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
start_logits: torch.FloatTensor = None
end_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeMultipleChoiceModelOutput(ModelOutput):
"""
Outputs of multiple choice models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
*num_choices* is the second dimension of the input tensors. (see *input_ids* above).
Classification scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
layer plus the initial entity embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class LukeEmbeddings(nn.Module):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -1240,15 +1383,20 @@ class LukeForMaskedLM(LukePreTrainedModel):
loss = loss + mep_loss
if not return_dict:
-output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
-if mlm_loss is not None and mep_loss is not None:
-return (loss, mlm_loss, mep_loss) + output
-elif mlm_loss is not None:
-return (loss, mlm_loss) + output
-elif mep_loss is not None:
-return (loss, mep_loss) + output
-else:
-return output
+return tuple(
+v
+for v in [
+loss,
+mlm_loss,
+mep_loss,
+logits,
+entity_logits,
+outputs.hidden_states,
+outputs.entity_hidden_states,
+outputs.attentions,
+]
+if v is not None
+)
return LukeMaskedLMOutput(
loss=loss,
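To illustrate the new filtered tuple return (a sketch, not part of the diff): with `return_dict=False`, `None` entries such as uncomputed losses are now dropped, so tuple positions shift depending on which inputs were passed.

```python
import torch
from transformers import LukeTokenizer, LukeForMaskedLM

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForMaskedLM.from_pretrained("studio-ousia/luke-base")
inputs = tokenizer("Beyoncé lives in <mask>.", return_tensors="pt")

# No labels are passed, so loss, mlm_loss, and mep_loss are None and are
# filtered out: the tuple starts directly with the prediction logits.
outputs = model(**inputs, return_dict=False)
logits = outputs[0]
```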
@@ -1360,13 +1508,11 @@ class LukeForEntityClassification(LukePreTrainedModel):
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
-output = (
-logits,
-outputs.hidden_states,
-outputs.entity_hidden_states,
-outputs.attentions,
-)
-return ((loss,) + output) if loss is not None else output
+return tuple(
+v
+for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
+if v is not None
+)
return EntityClassificationOutput(
loss=loss,
@@ -1480,13 +1626,11 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
-output = (
-logits,
-outputs.hidden_states,
-outputs.entity_hidden_states,
-outputs.attentions,
-)
-return ((loss,) + output) if loss is not None else output
+return tuple(
+v
+for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
+if v is not None
+)
return EntityPairClassificationOutput(
loss=loss,
@@ -1620,17 +1764,472 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
loss = nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits))
if not return_dict:
-output = (
-logits,
-outputs.hidden_states,
-outputs.entity_hidden_states,
-outputs.attentions,
-)
-return ((loss,) + output) if loss is not None else output
+return tuple(
+v
+for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
+if v is not None
+)
return EntitySpanClassificationOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
LUKE_START_DOCSTRING,
)
class LukeForSequenceClassification(LukePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.luke = LukeModel(config)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=LukeSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.FloatTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LukeSequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
return tuple(
v
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
if v is not None
)
return LukeSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
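As the branch above shows, the loss follows `config.problem_type`; a sketch of the regression path, using a deliberately tiny hypothetical config so it runs quickly:

```python
import torch
from transformers import LukeConfig, LukeForSequenceClassification

# Tiny illustrative config; num_labels=1 makes the head infer
# problem_type="regression" and use MSELoss.
config = LukeConfig(
    vocab_size=100, entity_vocab_size=10, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=128, num_labels=1,
)
model = LukeForSequenceClassification(config)

input_ids = torch.tensor([[0, 10, 20, 2]])  # hypothetical token ids
labels = torch.tensor([0.5])                # float target for regression
loss = model(input_ids, labels=labels).loss
```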
@add_start_docstrings(
"""
The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To
solve Named-Entity Recognition (NER) task using LUKE, `LukeForEntitySpanClassification` is more suitable than this
class.
""",
LUKE_START_DOCSTRING,
)
class LukeForTokenClassification(LukePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.luke = LukeModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=LukeTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.FloatTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LukeTokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels -
1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
sequence_output = outputs.last_hidden_state
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
return tuple(
v
for v in [loss, logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions]
if v is not None
)
return LukeTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
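A usage sketch for the token classification head with one label id per token (the tiny config, token ids, and label values are hypothetical):

```python
import torch
from transformers import LukeConfig, LukeForTokenClassification

config = LukeConfig(
    vocab_size=100, entity_vocab_size=10, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=128, num_labels=5,
)
model = LukeForTokenClassification(config)

input_ids = torch.tensor([[0, 10, 20, 2]])  # hypothetical token ids
labels = torch.tensor([[1, 0, 3, 0]])       # one label per token
outputs = model(input_ids, labels=labels)
print(outputs.logits.shape)  # torch.Size([1, 4, 5]): (batch, seq_len, num_labels)
```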
@add_start_docstrings(
"""
The LUKE Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
LUKE_START_DOCSTRING,
)
class LukeForQuestionAnswering(LukePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.luke = LukeModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=LukeQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.FloatTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LukeQuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
sequence_output = outputs.last_hidden_state
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the split may add an extra dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
return tuple(
v
for v in [
total_loss,
start_logits,
end_logits,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
]
if v is not None
)
return LukeQuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
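Decoding a span from the head's outputs, sketched below; the checkpoint is the public LUKE base model, whose QA head is untrained, so the extracted span is meaningless until fine-tuned:

```python
import torch
from transformers import LukeTokenizer, LukeForQuestionAnswering

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForQuestionAnswering.from_pretrained("studio-ousia/luke-base")

question, context = "Who proposed LUKE?", "LUKE was proposed by Yamada et al."
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Take the most likely start/end positions and decode the span between them.
start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
```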
@add_start_docstrings(
"""
The LUKE Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
LUKE_START_DOCSTRING,
)
class LukeForMultipleChoice(LukePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.luke = LukeModel(config)
self.dropout = nn.Dropout(
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.classifier = nn.Linear(config.hidden_size, 1)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=LukeMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
entity_ids: Optional[torch.LongTensor] = None,
entity_attention_mask: Optional[torch.FloatTensor] = None,
entity_token_type_ids: Optional[torch.LongTensor] = None,
entity_position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, LukeMultipleChoiceModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)
entity_ids = entity_ids.view(-1, entity_ids.size(-1)) if entity_ids is not None else None
entity_attention_mask = (
entity_attention_mask.view(-1, entity_attention_mask.size(-1))
if entity_attention_mask is not None
else None
)
entity_token_type_ids = (
entity_token_type_ids.view(-1, entity_token_type_ids.size(-1))
if entity_token_type_ids is not None
else None
)
entity_position_ids = (
entity_position_ids.view(-1, entity_position_ids.size(-2), entity_position_ids.size(-1))
if entity_position_ids is not None
else None
)
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
return tuple(
v
for v in [
loss,
reshaped_logits,
outputs.hidden_states,
outputs.entity_hidden_states,
outputs.attentions,
]
if v is not None
)
return LukeMultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
...
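The multiple choice head expects every tensor to carry an extra `num_choices` dimension, which the forward pass above flattens before calling `LukeModel`; a sketch with a tiny hypothetical config and ids:

```python
import torch
from transformers import LukeConfig, LukeForMultipleChoice

config = LukeConfig(
    vocab_size=100, entity_vocab_size=10, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=128,
)
model = LukeForMultipleChoice(config)

# Shape (batch_size=1, num_choices=2, seq_len=4); token ids are hypothetical.
input_ids = torch.tensor([[[0, 10, 20, 2], [0, 10, 30, 2]]])
labels = torch.tensor([0])  # index of the correct choice
outputs = model(input_ids, labels=labels)
print(outputs.logits.shape)  # torch.Size([1, 2]): (batch_size, num_choices)
```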
@@ -2736,6 +2736,34 @@ class LukeForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class LukeForMultipleChoice(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeModel(metaclass=DummyObject):
_backends = ["torch"]
...
@@ -30,6 +30,10 @@ if is_torch_available():
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukeTokenizer,
)
@@ -66,6 +70,8 @@ class LukeModelTester:
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
num_entity_classification_labels=9,
num_entity_pair_classification_labels=6,
num_entity_span_classification_labels=4,
@@ -99,6 +105,8 @@ class LukeModelTester:
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.num_entity_classification_labels = num_entity_classification_labels
self.num_entity_pair_classification_labels = num_entity_pair_classification_labels
self.num_entity_span_classification_labels = num_entity_span_classification_labels
@@ -139,7 +147,8 @@ class LukeModelTester:
)
sequence_labels = None
-labels = None
+token_labels = None
choice_labels = None
entity_labels = None
entity_classification_labels = None
entity_pair_classification_labels = None
@@ -147,7 +156,9 @@ class LukeModelTester:
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
@@ -170,7 +181,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
-labels,
+token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -207,7 +219,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
-labels,
+token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -247,7 +260,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
-labels,
+token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -266,7 +280,7 @@ class LukeModelTester:
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
-labels=labels,
+labels=token_labels,
entity_labels=entity_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
@@ -288,7 +302,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
-labels,
+token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -322,7 +337,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
-labels,
+token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -356,7 +372,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
-labels,
+token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -386,6 +403,156 @@ class LukeModelTester:
result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
)
def create_and_check_for_question_answering(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
model = LukeForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_for_sequence_classification(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_labels
model = LukeForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=sequence_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_labels
model = LukeForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=token_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_multiple_choice(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_choices = self.num_choices
model = LukeForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_entity_token_type_ids = (
entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
)
multiple_choice_entity_attention_mask = (
entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
)
multiple_choice_entity_position_ids = (
entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous()
)
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_attention_mask,
token_type_ids=multiple_choice_token_type_ids,
entity_ids=multiple_choice_entity_ids,
entity_attention_mask=multiple_choice_entity_attention_mask,
entity_token_type_ids=multiple_choice_entity_token_type_ids,
entity_position_ids=multiple_choice_entity_position_ids,
labels=choice_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -398,7 +565,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
-labels,
+token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
@@ -426,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeForMultipleChoice,
)
if is_torch_available()
else ()
@@ -436,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
test_head_masking = True
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")}
inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")}
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if model_class == LukeForMultipleChoice:
entity_inputs_dict = {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if v.ndim == 2
else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous()
for k, v in entity_inputs_dict.items()
}
inputs_dict.update(entity_inputs_dict)
if model_class == LukeForEntitySpanClassification:
inputs_dict["entity_start_positions"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
@@ -446,7 +630,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
)
if return_labels:
-if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
+if model_class in (
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForSequenceClassification,
LukeForMultipleChoice,
):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device self.model_tester.batch_size, dtype=torch.long, device=torch_device
) )
@@ -456,6 +645,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForTokenClassification:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForMaskedLM:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
@@ -496,6 +691,22 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
...