Unverified Commit 62098b93 authored by Ikuya Yamada's avatar Ikuya Yamada Committed by GitHub
Browse files

Adding fine-tuning models to LUKE (#18353)

* add LUKE models for downstream tasks

* add new LUKE models to docs

* fix typos

* remove commented lines

* exclude None items from tuple return values
parent 7b9e995b
......@@ -152,3 +152,23 @@ This model was contributed by [ikuyamada](https://huggingface.co/ikuyamada) and
[[autodoc]] LukeForEntitySpanClassification
- forward
## LukeForSequenceClassification
[[autodoc]] LukeForSequenceClassification
- forward
## LukeForMultipleChoice
[[autodoc]] LukeForMultipleChoice
- forward
## LukeForTokenClassification
[[autodoc]] LukeForTokenClassification
- forward
## LukeForQuestionAnswering
[[autodoc]] LukeForQuestionAnswering
- forward
......@@ -1363,6 +1363,10 @@ else:
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"LukeForMultipleChoice",
"LukeForQuestionAnswering",
"LukeForSequenceClassification",
"LukeForTokenClassification",
"LukeForMaskedLM",
"LukeModel",
"LukePreTrainedModel",
......@@ -3953,6 +3957,10 @@ if TYPE_CHECKING:
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukePreTrainedModel,
)
......
......@@ -170,6 +170,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("lxmert", "LxmertForPreTraining"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
......@@ -230,6 +231,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("led", "LEDForConditionalGeneration"),
("longformer", "LongformerForMaskedLM"),
("longt5", "LongT5ForConditionalGeneration"),
("luke", "LukeForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("marian", "MarianMTModel"),
("megatron-bert", "MegatronBertForCausalLM"),
......@@ -499,6 +501,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
("led", "LEDForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("luke", "LukeForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("megatron-bert", "MegatronBertForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
......@@ -551,6 +554,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
("led", "LEDForQuestionAnswering"),
("longformer", "LongformerForQuestionAnswering"),
("luke", "LukeForQuestionAnswering"),
("lxmert", "LxmertForQuestionAnswering"),
("mbart", "MBartForQuestionAnswering"),
("megatron-bert", "MegatronBertForQuestionAnswering"),
......@@ -611,6 +615,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
("layoutlmv3", "LayoutLMv3ForTokenClassification"),
("longformer", "LongformerForTokenClassification"),
("luke", "LukeForTokenClassification"),
("megatron-bert", "MegatronBertForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("mpnet", "MPNetForTokenClassification"),
......@@ -647,6 +652,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
("funnel", "FunnelForMultipleChoice"),
("ibert", "IBertForMultipleChoice"),
("longformer", "LongformerForMultipleChoice"),
("luke", "LukeForMultipleChoice"),
("megatron-bert", "MegatronBertForMultipleChoice"),
("mobilebert", "MobileBertForMultipleChoice"),
("mpnet", "MPNetForMultipleChoice"),
......
......@@ -37,6 +37,10 @@ else:
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"LukeForMultipleChoice",
"LukeForQuestionAnswering",
"LukeForSequenceClassification",
"LukeForTokenClassification",
"LukeForMaskedLM",
"LukeModel",
"LukePreTrainedModel",
......@@ -59,6 +63,10 @@ if TYPE_CHECKING:
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukePreTrainedModel,
)
......
......@@ -74,6 +74,8 @@ class LukeConfig(PretrainedConfig):
Whether or not the model should use the entity-aware self-attention mechanism proposed in [LUKE: Deep
Contextualized Entity Representations with Entity-aware Self-attention (Yamada et
al.)](https://arxiv.org/abs/2010.01057).
classifier_dropout (`float`, *optional*):
The dropout ratio for the classification head.
Examples:
......@@ -108,6 +110,7 @@ class LukeConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
use_entity_aware_attention=True,
classifier_dropout=None,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
......@@ -131,3 +134,4 @@ class LukeConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_entity_aware_attention = use_entity_aware_attention
self.classifier_dropout = classifier_dropout
......@@ -2736,6 +2736,34 @@ class LukeForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["torch"])
class LukeForMultipleChoice(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForQuestionAnswering(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeModel(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -30,6 +30,10 @@ if is_torch_available():
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeForMultipleChoice,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeModel,
LukeTokenizer,
)
......@@ -66,6 +70,8 @@ class LukeModelTester:
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
num_entity_classification_labels=9,
num_entity_pair_classification_labels=6,
num_entity_span_classification_labels=4,
......@@ -99,6 +105,8 @@ class LukeModelTester:
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.num_entity_classification_labels = num_entity_classification_labels
self.num_entity_pair_classification_labels = num_entity_pair_classification_labels
self.num_entity_span_classification_labels = num_entity_span_classification_labels
......@@ -139,7 +147,8 @@ class LukeModelTester:
)
sequence_labels = None
labels = None
token_labels = None
choice_labels = None
entity_labels = None
entity_classification_labels = None
entity_pair_classification_labels = None
......@@ -147,7 +156,9 @@ class LukeModelTester:
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
......@@ -170,7 +181,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
......@@ -207,7 +219,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
......@@ -247,7 +260,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
......@@ -266,7 +280,7 @@ class LukeModelTester:
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=labels,
labels=token_labels,
entity_labels=entity_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
......@@ -288,7 +302,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
......@@ -322,7 +337,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
......@@ -356,7 +372,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
......@@ -386,6 +403,156 @@ class LukeModelTester:
result.logits.shape, (self.batch_size, self.entity_length, self.num_entity_span_classification_labels)
)
def create_and_check_for_question_answering(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
model = LukeForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_for_sequence_classification(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_labels
model = LukeForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=sequence_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_token_classification(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_labels
model = LukeForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=token_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_multiple_choice(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_choices = self.num_choices
model = LukeForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_attention_mask = attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_entity_ids = entity_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_entity_token_type_ids = (
entity_token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
)
multiple_choice_entity_attention_mask = (
entity_attention_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
)
multiple_choice_entity_position_ids = (
entity_position_ids.unsqueeze(1).expand(-1, self.num_choices, -1, -1).contiguous()
)
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_attention_mask,
token_type_ids=multiple_choice_token_type_ids,
entity_ids=multiple_choice_entity_ids,
entity_attention_mask=multiple_choice_entity_attention_mask,
entity_token_type_ids=multiple_choice_entity_token_type_ids,
entity_position_ids=multiple_choice_entity_position_ids,
labels=choice_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
......@@ -398,7 +565,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
token_labels,
choice_labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
......@@ -426,6 +594,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForQuestionAnswering,
LukeForSequenceClassification,
LukeForTokenClassification,
LukeForMultipleChoice,
)
if is_torch_available()
else ()
......@@ -436,7 +608,19 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
test_head_masking = True
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
entity_inputs_dict = {k: v for k, v in inputs_dict.items() if k.startswith("entity")}
inputs_dict = {k: v for k, v in inputs_dict.items() if not k.startswith("entity")}
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if model_class == LukeForMultipleChoice:
entity_inputs_dict = {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if v.ndim == 2
else v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1, -1).contiguous()
for k, v in entity_inputs_dict.items()
}
inputs_dict.update(entity_inputs_dict)
if model_class == LukeForEntitySpanClassification:
inputs_dict["entity_start_positions"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.entity_length), dtype=torch.long, device=torch_device
......@@ -446,7 +630,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
)
if return_labels:
if model_class in (LukeForEntityClassification, LukeForEntityPairClassification):
if model_class in (
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForSequenceClassification,
LukeForMultipleChoice,
):
inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device
)
......@@ -456,6 +645,12 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForTokenClassification:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForMaskedLM:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
......@@ -496,6 +691,22 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment