Unverified Commit 4d1ce396 authored by Ryokan RI's avatar Ryokan RI Committed by GitHub
Browse files

Debug LukeForMaskedLM (#17499)

* add a test for a word only input

* make LukeForMaskedLM work without entity inputs

* update test

* add LukeForMaskedLM to MODEL_FOR_MASKED_LM_MAPPING_NAMES

* restore pyproject.toml

* empty line at the end of pyproject.toml
parent 4390151b
...@@ -377,6 +377,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( ...@@ -377,6 +377,7 @@ MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
("ibert", "IBertForMaskedLM"), ("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"), ("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"), ("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("mbart", "MBartForConditionalGeneration"), ("mbart", "MBartForConditionalGeneration"),
("megatron-bert", "MegatronBertForMaskedLM"), ("megatron-bert", "MegatronBertForMaskedLM"),
("mobilebert", "MobileBertForMaskedLM"), ("mobilebert", "MobileBertForMaskedLM"),
......
...@@ -1229,13 +1229,15 @@ class LukeForMaskedLM(LukePreTrainedModel): ...@@ -1229,13 +1229,15 @@ class LukeForMaskedLM(LukePreTrainedModel):
loss = mlm_loss loss = mlm_loss
mep_loss = None mep_loss = None
entity_logits = self.entity_predictions(outputs.entity_last_hidden_state) entity_logits = None
if entity_labels is not None: if outputs.entity_last_hidden_state is not None:
mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1)) entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
if loss is None: if entity_labels is not None:
loss = mep_loss mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
else: if loss is None:
loss = loss + mep_loss loss = mep_loss
else:
loss = loss + mep_loss
if not return_dict: if not return_dict:
output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions) output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
......
...@@ -270,9 +270,12 @@ class LukeModelTester: ...@@ -270,9 +270,12 @@ class LukeModelTester:
entity_labels=entity_labels, entity_labels=entity_labels,
) )
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual( if entity_ids is not None:
result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size) self.parent.assertEqual(
) result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
)
else:
self.parent.assertIsNone(result.entity_logits)
def create_and_check_for_entity_classification( def create_and_check_for_entity_classification(
self, self,
...@@ -488,6 +491,11 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -488,6 +491,11 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_masked_lm_with_word_only(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs = (*config_and_inputs[:4], *((None,) * len(config_and_inputs[4:])))
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_entity_classification(self): def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs) self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment