Unverified Commit 30646a0a authored by Ryokan RI, committed by GitHub

Add mLUKE (#14640)

* implement MLukeTokenizer and LukeForMaskedLM

* update tests

* update docs

* add LukeForMaskedLM to check_repo.py

* update README

* fix test and specify the entity pad id in tokenization_(m)luke

* fix EntityPredictionHeadTransform
parent 4cdb67ca
...@@ -29,6 +29,7 @@ if is_torch_available():
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeModel,
LukeTokenizer,
)
...@@ -138,12 +139,17 @@ class LukeModelTester:
)
sequence_labels = None
labels = None
entity_labels = None
entity_classification_labels = None
entity_pair_classification_labels = None
entity_span_classification_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
entity_pair_classification_labels = ids_tensor(
[self.batch_size], self.num_entity_pair_classification_labels
...@@ -164,6 +170,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
...@@ -199,6 +207,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
...@@ -226,6 +236,44 @@ class LukeModelTester:
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_masked_lm(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_entity_classification_labels
model = LukeForMaskedLM(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=labels,
entity_labels=entity_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(
result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
)
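A minimal usage sketch of what create_and_check_for_masked_lm verifies, run against a real checkpoint instead of the tester's random tensors. The studio-ousia/luke-base checkpoint name is an assumption, not part of this diff; the two output heads and their shapes mirror the assertions above.

import torch
from transformers import LukeTokenizer, LukeForMaskedLM

# Assumed checkpoint; any LUKE checkpoint with an entity vocabulary should behave the same.
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeForMaskedLM.from_pretrained("studio-ousia/luke-base")
model.eval()

text = "Beyonce lives in Los Angeles."
start = text.index("Los Angeles")
# Passing only entity_spans makes the tokenizer fill the entity sequence with [MASK] entities.
inputs = tokenizer(text, entity_spans=[(start, start + len("Los Angeles"))], return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)         # word head: (batch_size, seq_length, vocab_size)
print(outputs.entity_logits.shape)  # entity head: (batch_size, entity_length, entity_vocab_size)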
def create_and_check_for_entity_classification(
self,
config,
...@@ -237,6 +285,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
...@@ -269,6 +319,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
...@@ -301,6 +353,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
...@@ -341,6 +395,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
...@@ -363,6 +419,7 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
LukeModel,
LukeForMaskedLM,
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
...@@ -396,6 +453,18 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForMaskedLM:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
dtype=torch.long,
device=torch_device,
)
inputs_dict["entity_labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.entity_length),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def setUp(self):
...@@ -415,6 +484,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
model = LukeModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
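test_for_masked_lm drives the check above with the random tensors built in prepare_config_and_inputs via ids_tensor. Roughly, that helper does the following (a sketch of the shared test utility, not code from this diff):

import random
import torch

def ids_tensor(shape, vocab_size, rng=None):
    # Random torch.long tensor of the given shape with values in [0, vocab_size),
    # assumed to match the helper imported by the LUKE model tests.
    rng = rng or random.Random()
    total = 1
    for dim in shape:
        total *= dim
    values = [rng.randint(0, vocab_size - 1) for _ in range(total)]
    return torch.tensor(values, dtype=torch.long).view(shape)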
def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
...
...@@ -23,7 +23,7 @@ from transformers.testing_utils import require_torch, slow
from .test_tokenization_common import TokenizerTesterMixin
-class Luke(TokenizerTesterMixin, unittest.TestCase):
+class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LukeTokenizer
test_rust_tokenizer = False
from_pretrained_kwargs = {"cls_token": "<s>"}
...@@ -79,8 +79,8 @@ class Luke(TokenizerTesterMixin, unittest.TestCase):
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
-assert encoded_sentence == encoded_text_from_decode
-assert encoded_pair == encoded_pair_from_decode
+self.assertEqual(encoded_sentence, encoded_text_from_decode)
+self.assertEqual(encoded_pair, encoded_pair_from_decode)
def get_clean_sequence(self, tokenizer, max_length=20) -> Tuple[str, list]:
txt = "Beyonce lives in Los Angeles"
...@@ -159,6 +159,81 @@ class Luke(TokenizerTesterMixin, unittest.TestCase):
tokens_p_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
)
def test_padding_entity_inputs(self):
tokenizer = self.get_tokenizer()
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)
pad_id = tokenizer.entity_vocab["[PAD]"]
mask_id = tokenizer.entity_vocab["[MASK]"]
encoding = tokenizer([sentence, sentence], entity_spans=[[span], [span, span]], padding=True)
self.assertEqual(encoding["entity_ids"], [[mask_id, pad_id], [mask_id, mask_id]])
# test with a sentence with no entity
encoding = tokenizer([sentence, sentence], entity_spans=[[], [span, span]], padding=True)
self.assertEqual(encoding["entity_ids"], [[pad_id, pad_id], [mask_id, mask_id]])
def test_if_tokenize_single_text_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer()
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
spans = [(15, 34)]
entities = ["East Asian language"]
with self.assertRaises(ValueError):
tokenizer(sentence, entities=tuple(entities), entity_spans=spans)
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=tuple(spans))
with self.assertRaises(ValueError):
tokenizer(sentence, entities=[0], entity_spans=spans)
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=[0])
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=spans + [(0, 9)])
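For contrast with the failure cases above, the accepted call shape: entities and entity_spans are lists (not tuples) of equal length, entities are strings, and each span is a (start, end) pair of character offsets. A sketch, with the checkpoint name as an assumption:

from transformers import LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")  # assumed checkpoint
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."

entities = ["East Asian language", "Japan"]
start_language = sentence.index("East Asian language")
start_japan = sentence.rindex("Japan")
entity_spans = [
    (start_language, start_language + len("East Asian language")),
    (start_japan, start_japan + len("Japan")),
]

encoding = tokenizer(sentence, entities=entities, entity_spans=entity_spans)
print(list(encoding.keys()))  # includes entity_ids, entity_attention_mask, entity_position_ids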
def test_if_tokenize_entity_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[span, span])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0])
def test_if_tokenize_entity_pair_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_pair_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
# head and tail information
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0, 0])
def test_if_tokenize_entity_span_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_span_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0, 0, 0])
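The three task-specific tests above pin down the expected number of spans: entity_classification takes exactly one span, entity_pair_classification exactly two (head and tail), and entity_span_classification one or more, always as (start, end) character tuples. A sketch of the corresponding valid calls, with the checkpoint name as an assumption:

from transformers import LukeTokenizer

sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span_language = (15, 34)  # "East Asian language"
start_japan = sentence.rindex("Japan")
span_japan = (start_japan, start_japan + len("Japan"))

# Exactly one span.
tok = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification")
enc = tok(sentence, entity_spans=[span_language])

# Exactly two spans: head entity, then tail entity.
tok = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")
enc = tok(sentence, entity_spans=[span_language, span_japan])

# One or more candidate spans.
tok = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_span_classification")
enc = tok(sentence, entity_spans=[span_language, span_japan])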
@require_torch
class LukeTokenizerIntegrationTests(unittest.TestCase):
...
...@@ -116,6 +116,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"DPRReader",
"FlaubertForQuestionAnswering",
"GPT2DoubleHeadsModel",
"LukeForMaskedLM",
"LukeForEntityClassification", "LukeForEntityClassification",
"LukeForEntityPairClassification", "LukeForEntityPairClassification",
"LukeForEntitySpanClassification", "LukeForEntitySpanClassification",
......