"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "f11fc7cf4732ff2393b92793bdf53967defbc2c7"
Unverified commit 62098b93, authored by Ikuya Yamada and committed by GitHub

Adding fine-tuning models to LUKE (#18353)

* add LUKE models for downstream tasks

* add new LUKE models to docs

* fix typos

* remove commented lines

* exclude None items from tuple return values
parent 7b9e995b
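For orientation before the diff (not part of the commit): a minimal usage sketch of one of the new task heads. It assumes the public `studio-ousia/luke-base` checkpoint and a `transformers` version that includes this change; the label count and example sentence are illustrative, and the classification head starts from randomly initialized weights until fine-tuned.

```python
# Hedged sketch: exercise the new LukeForSequenceClassification head.
# Assumptions: the "studio-ousia/luke-base" checkpoint is available and this PR is installed.
import torch
from transformers import LukeTokenizer, LukeForSequenceClassification

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
# num_labels=2 is illustrative; the sequence-classification head is newly initialized.
model = LukeForSequenceClassification.from_pretrained("studio-ousia/luke-base", num_labels=2)

inputs = tokenizer("Beyoncé lives in Los Angeles.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)  # torch.Size([1, 2])
```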
@@ -152,3 +152,23 @@ This model was contributed by [ikuyamada](https://huggingface.co/ikuyamada) and
 [[autodoc]] LukeForEntitySpanClassification
     - forward
+
+## LukeForSequenceClassification
+
+[[autodoc]] LukeForSequenceClassification
+    - forward
+
+## LukeForMultipleChoice
+
+[[autodoc]] LukeForMultipleChoice
+    - forward
+
+## LukeForTokenClassification
+
+[[autodoc]] LukeForTokenClassification
+    - forward
+
+## LukeForQuestionAnswering
+
+[[autodoc]] LukeForQuestionAnswering
+    - forward
@@ -1363,6 +1363,10 @@ else:
             "LukeForEntityClassification",
             "LukeForEntityPairClassification",
             "LukeForEntitySpanClassification",
+            "LukeForMultipleChoice",
+            "LukeForQuestionAnswering",
+            "LukeForSequenceClassification",
+            "LukeForTokenClassification",
             "LukeForMaskedLM",
             "LukeModel",
             "LukePreTrainedModel",
@@ -3953,6 +3957,10 @@ if TYPE_CHECKING:
            LukeForEntityPairClassification,
            LukeForEntitySpanClassification,
            LukeForMaskedLM,
+           LukeForMultipleChoice,
+           LukeForQuestionAnswering,
+           LukeForSequenceClassification,
+           LukeForTokenClassification,
            LukeModel,
            LukePreTrainedModel,
        )
......
@@ -170,6 +170,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
         ("ibert", "IBertForMaskedLM"),
         ("layoutlm", "LayoutLMForMaskedLM"),
         ("longformer", "LongformerForMaskedLM"),
+        ("luke", "LukeForMaskedLM"),
         ("lxmert", "LxmertForPreTraining"),
         ("megatron-bert", "MegatronBertForPreTraining"),
         ("mobilebert", "MobileBertForPreTraining"),
@@ -230,6 +231,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
         ("led", "LEDForConditionalGeneration"),
         ("longformer", "LongformerForMaskedLM"),
         ("longt5", "LongT5ForConditionalGeneration"),
+        ("luke", "LukeForMaskedLM"),
         ("m2m_100", "M2M100ForConditionalGeneration"),
         ("marian", "MarianMTModel"),
         ("megatron-bert", "MegatronBertForCausalLM"),
@@ -499,6 +501,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
         ("led", "LEDForSequenceClassification"),
         ("longformer", "LongformerForSequenceClassification"),
+        ("luke", "LukeForSequenceClassification"),
         ("mbart", "MBartForSequenceClassification"),
         ("megatron-bert", "MegatronBertForSequenceClassification"),
         ("mobilebert", "MobileBertForSequenceClassification"),
@@ -551,6 +554,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
         ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
         ("led", "LEDForQuestionAnswering"),
         ("longformer", "LongformerForQuestionAnswering"),
+        ("luke", "LukeForQuestionAnswering"),
         ("lxmert", "LxmertForQuestionAnswering"),
         ("mbart", "MBartForQuestionAnswering"),
         ("megatron-bert", "MegatronBertForQuestionAnswering"),
@@ -611,6 +615,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
         ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
         ("longformer", "LongformerForTokenClassification"),
+        ("luke", "LukeForTokenClassification"),
         ("megatron-bert", "MegatronBertForTokenClassification"),
         ("mobilebert", "MobileBertForTokenClassification"),
         ("mpnet", "MPNetForTokenClassification"),
@@ -647,6 +652,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
         ("funnel", "FunnelForMultipleChoice"),
         ("ibert", "IBertForMultipleChoice"),
         ("longformer", "LongformerForMultipleChoice"),
+        ("luke", "LukeForMultipleChoice"),
         ("megatron-bert", "MegatronBertForMultipleChoice"),
         ("mobilebert", "MobileBertForMultipleChoice"),
         ("mpnet", "MPNetForMultipleChoice"),
......
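For reference, not part of the diff: once these mappings are registered, the Auto classes resolve a LUKE checkpoint to the new heads. A minimal sketch, assuming the public `studio-ousia/luke-base` checkpoint and an illustrative `num_labels`:

```python
# Hedged sketch of what the new auto-mappings enable (assumes "studio-ousia/luke-base").
from transformers import AutoModelForQuestionAnswering, AutoModelForSequenceClassification

seq_model = AutoModelForSequenceClassification.from_pretrained("studio-ousia/luke-base", num_labels=3)
qa_model = AutoModelForQuestionAnswering.from_pretrained("studio-ousia/luke-base")

print(type(seq_model).__name__)  # LukeForSequenceClassification
print(type(qa_model).__name__)   # LukeForQuestionAnswering
```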
@@ -37,6 +37,10 @@ else:
        "LukeForEntityClassification",
        "LukeForEntityPairClassification",
        "LukeForEntitySpanClassification",
+       "LukeForMultipleChoice",
+       "LukeForQuestionAnswering",
+       "LukeForSequenceClassification",
+       "LukeForTokenClassification",
        "LukeForMaskedLM",
        "LukeModel",
        "LukePreTrainedModel",
@@ -59,6 +63,10 @@ if TYPE_CHECKING:
            LukeForEntityPairClassification,
            LukeForEntitySpanClassification,
            LukeForMaskedLM,
+           LukeForMultipleChoice,
+           LukeForQuestionAnswering,
+           LukeForSequenceClassification,
+           LukeForTokenClassification,
            LukeModel,
            LukePreTrainedModel,
        )
......
@@ -74,6 +74,8 @@ class LukeConfig(PretrainedConfig):
             Whether or not the model should use the entity-aware self-attention mechanism proposed in [LUKE: Deep
             Contextualized Entity Representations with Entity-aware Self-attention (Yamada et
             al.)](https://arxiv.org/abs/2010.01057).
+        classifier_dropout (`float`, *optional*):
+            The dropout ratio for the classification head.
 
     Examples:
@@ -108,6 +110,7 @@ class LukeConfig(PretrainedConfig):
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         use_entity_aware_attention=True,
+        classifier_dropout=None,
         pad_token_id=1,
         bos_token_id=0,
         eos_token_id=2,
@@ -131,3 +134,4 @@ class LukeConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.use_entity_aware_attention = use_entity_aware_attention
+        self.classifier_dropout = classifier_dropout
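Not part of the diff: a small sketch of the new `classifier_dropout` field. The fallback to `hidden_dropout_prob` when the value is `None` follows the usual transformers convention and is an assumption here, not something stated in this hunk.

```python
# Hedged sketch: the new classifier_dropout config field.
from transformers import LukeConfig

config = LukeConfig(classifier_dropout=0.2)
print(config.classifier_dropout)  # 0.2

default_config = LukeConfig()
# None by default; the classification heads presumably fall back to
# hidden_dropout_prob (assumption based on the usual transformers pattern).
print(default_config.classifier_dropout)  # None
```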
@@ -2736,6 +2736,34 @@ class LukeForMaskedLM(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class LukeForMultipleChoice(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class LukeForQuestionAnswering(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class LukeForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class LukeForTokenClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class LukeModel(metaclass=DummyObject):
     _backends = ["torch"]
......
@@ -30,6 +30,10 @@ if is_torch_available():
        LukeForEntityPairClassification,
        LukeForEntitySpanClassification,
        LukeForMaskedLM,
+       LukeForMultipleChoice,
+       LukeForQuestionAnswering,
+       LukeForSequenceClassification,
+       LukeForTokenClassification,
        LukeModel,
        LukeTokenizer,
    )
@@ -66,6 +70,8 @@ class LukeModelTester:
         type_vocab_size=16,
         type_sequence_label_size=2,
         initializer_range=0.02,
+        num_labels=3,
+        num_choices=4,
         num_entity_classification_labels=9,
         num_entity_pair_classification_labels=6,
         num_entity_span_classification_labels=4,
@@ -99,6 +105,8 @@ class LukeModelTester:
         self.type_vocab_size = type_vocab_size
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
+        self.num_labels = num_labels
+        self.num_choices = num_choices
         self.num_entity_classification_labels = num_entity_classification_labels
         self.num_entity_pair_classification_labels = num_entity_pair_classification_labels
         self.num_entity_span_classification_labels = num_entity_span_classification_labels
@@ -139,7 +147,8 @@ class LukeModelTester:
         )
 
         sequence_labels = None
-        labels = None
+        token_labels = None
+        choice_labels = None
         entity_labels = None
         entity_classification_labels = None
         entity_pair_classification_labels = None
@@ -147,7 +156,9 @@ class LukeModelTester:
         if self.use_labels:
             sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+            choice_labels = ids_tensor([self.batch_size], self.num_choices)
             entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
 
             entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
@@ -170,7 +181,8 @@ class LukeModelTester:
             entity_token_type_ids,
             entity_position_ids,
             sequence_labels,
-            labels,
+            token_labels,
+            choice_labels,
             entity_labels,
             entity_classification_labels,
             entity_pair_classification_labels,
@@ -207,7 +219,8 @@ class LukeModelTester:
         entity_token_type_ids,
         entity_position_ids,
         sequence_labels,
-        labels,
+        token_labels,
+        choice_labels,
         entity_labels,
         entity_classification_labels,
         entity_pair_classification_labels,
@@ -247,7 +260,8 @@ class LukeModelTester:
         entity_token_type_ids,
         entity_position_ids,
         sequence_labels,
-        labels,
+        token_labels,
+        choice_labels,
         entity_labels,
         entity_classification_labels,
         entity_pair_classification_labels,
@@ -266,7 +280,7 @@ class LukeModelTester:
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
-           labels=labels,
+           labels=token_labels,
            entity_labels=entity_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
@@ -288,7 +302,8 @@ class LukeModelTester:
         entity_token_type_ids,
         entity_position_ids,
         sequence_labels,
-        labels,
+        token_labels,
+        choice_labels,
         entity_labels,
         entity_classification_labels,
         entity_pair_classification_labels,
@@ -322,7 +337,8 @@ class LukeModelTester:
         entity_token_type_ids,
         entity_position_ids,
         sequence_labels,
-        labels,
+        token_labels,
+        choice_labels,
         entity_labels,
         entity_classification_labels,
         entity_pair_classification_labels,
@@ -356,7 +372,8 @@ class LukeModelTester:
         entity_token_type_ids,
         entity_position_ids,
         sequence_labels,
-        labels,
+        token_labels,
+        choice_labels,
         entity_labels,
         entity_classification_labels,
         entity_pair_classification_labels,
@@ -386,6 +403,156 @@ class LukeModelTester:
             result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
         )
 
+    def create_and_check_for_question_answering(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        entity_ids,
+        entity_attention_mask,
+        entity_token_type_ids,
+        entity_position_ids,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        entity_labels,
+        entity_classification_labels,
+        entity_pair_classification_labels,
+        entity_span_classification_labels,
+    ):
+        model = LukeForQuestionAnswering(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            entity_ids=entity_ids,
+            entity_attention_mask=entity_attention_mask,
+            entity_token_type_ids=entity_token_type_ids,
+            entity_position_ids=entity_position_ids,
+            start_positions=sequence_labels,
+            end_positions=sequence_labels,
+        )
+        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
+        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
+
+    def create_and_check_for_sequence_classification(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        entity_ids,
+        entity_attention_mask,
+        entity_token_type_ids,
+        entity_position_ids,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        entity_labels,
+        entity_classification_labels,
+        entity_pair_classification_labels,
+        entity_span_classification_labels,
+    ):
+        config.num_labels = self.num_labels
+        model = LukeForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            entity_ids=entity_ids,
+            entity_attention_mask=entity_attention_mask,
+            entity_token_type_ids=entity_token_type_ids,
+            entity_position_ids=entity_position_ids,
+            labels=sequence_labels,
+        )
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+    def create_and_check_for_token_classification(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        entity_ids,
+        entity_attention_mask,
+        entity_token_type_ids,
+        entity_position_ids,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        entity_labels,
+        entity_classification_labels,
+        entity_pair_classification_labels,
+        entity_span_classification_labels,
+    ):
+        config.num_labels = self.num_labels
+        model = LukeForTokenClassification(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            entity_ids=entity_ids,
+            entity_attention_mask=entity_attention_mask,
+            entity_token_type_ids=entity_token_type_ids,
+            entity_position_ids=entity_position_ids,
+            labels=token_labels,
+        )
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
+
+    def create_and_check_for_multiple_choice(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        entity_ids,
+        entity_attention_mask,
+        entity_token_type_ids,
+        entity_position_ids,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        entity_labels,
+        entity_classification_labels,
+        entity_pair_classification_labels,
+        entity_span_classification_labels,
+    ):
+        config.num_choices = self.num_choices
+        model = LukeForMultipleChoice(config=config)
+        model.to(torch_device)
+        model.eval()
+        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_entity_token_type_ids = (
+            entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        )
+        multiple_choice_entity_attention_mask = (
+            entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        )
+        multiple_choice_entity_position_ids = (
+            entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous()
+        )
+        result = model(
+            multiple_choice_inputs_ids,
+            attention_mask=multiple_choice_attention_mask,
+            token_type_ids=multiple_choice_token_type_ids,
+            entity_ids=multiple_choice_entity_ids,
+            entity_attention_mask=multiple_choice_entity_attention_mask,
+            entity_token_type_ids=multiple_choice_entity_token_type_ids,
+            entity_position_ids=multiple_choice_entity_position_ids,
+            labels=choice_labels,
+        )
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -398,7 +565,8 @@ class LukeModelTester:
             entity_token_type_ids,
             entity_position_ids,
             sequence_labels,
-            labels,
+            token_labels,
+            choice_labels,
             entity_labels,
             entity_classification_labels,
             entity_pair_classification_labels,
@@ -426,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
            LukeForEntityClassification,
            LukeForEntityPairClassification,
            LukeForEntitySpanClassification,
+           LukeForQuestionAnswering,
+           LukeForSequenceClassification,
+           LukeForTokenClassification,
+           LukeForMultipleChoice,
        )
        if is_torch_available()
        else ()
@@ -436,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
     test_head_masking = True
 
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")}
+        inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")}
+
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+        if model_class == LukeForMultipleChoice:
+            entity_inputs_dict = {
+                k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
+                if v.ndim == 2
+                else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous()
+                for k, v in entity_inputs_dict.items()
+            }
+        inputs_dict.update(entity_inputs_dict)
+
         if model_class == LukeForEntitySpanClassification:
             inputs_dict["entity_start_positions"] = torch.zeros(
                 (self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
@@ -446,7 +630,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
             )
 
         if return_labels:
-            if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
+            if model_class in (
+                LukeForEntityClassification,
+                LukeForEntityPairClassification,
+                LukeForSequenceClassification,
+                LukeForMultipleChoice,
+            ):
                 inputs_dict["labels"] = torch.zeros(
                     self.model_tester.batch_size, dtype=torch.long, device=torch_device
                 )
@@ -456,6 +645,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
                     dtype=torch.long,
                     device=torch_device,
                 )
+            elif model_class == LukeForTokenClassification:
+                inputs_dict["labels"] = torch.zeros(
+                    (self.model_tester.batch_size, self.model_tester.seq_length),
+                    dtype=torch.long,
+                    device=torch_device,
+                )
             elif model_class == LukeForMaskedLM:
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length),
@@ -496,6 +691,22 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
 
+    def test_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
+    def test_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+
     def test_for_entity_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
......